diff options
Diffstat (limited to 'contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp')
-rw-r--r-- | contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp | 910 |
1 files changed, 450 insertions, 460 deletions
diff --git a/contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp b/contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp index 24b035d67598..c97244328d37 100644 --- a/contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp +++ b/contrib/llvm/tools/clang/lib/CodeGen/CodeGenPGO.cpp @@ -58,12 +58,16 @@ void CodeGenPGO::setFuncName(llvm::Function *Fn) { } void CodeGenPGO::createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage) { - // Usually, we want to match the function's linkage, but - // available_externally and extern_weak both have the wrong semantics. + // We generally want to match the function's linkage, but available_externally + // and extern_weak both have the wrong semantics, and anything that doesn't + // need to link across compilation units doesn't need to be visible at all. if (Linkage == llvm::GlobalValue::ExternalWeakLinkage) Linkage = llvm::GlobalValue::LinkOnceAnyLinkage; else if (Linkage == llvm::GlobalValue::AvailableExternallyLinkage) Linkage = llvm::GlobalValue::LinkOnceODRLinkage; + else if (Linkage == llvm::GlobalValue::InternalLinkage || + Linkage == llvm::GlobalValue::ExternalLinkage) + Linkage = llvm::GlobalValue::PrivateLinkage; auto *Value = llvm::ConstantDataArray::getString(CGM.getLLVMContext(), FuncName, false); @@ -138,482 +142,469 @@ const int PGOHash::NumBitsPerType; const unsigned PGOHash::NumTypesPerWord; const unsigned PGOHash::TooBig; - /// A RecursiveASTVisitor that fills a map of statements to PGO counters. - struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { - /// The next counter value to assign. - unsigned NextCounter; - /// The function hash. - PGOHash Hash; - /// The map of statements to counters. - llvm::DenseMap<const Stmt *, unsigned> &CounterMap; - - MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap) - : NextCounter(0), CounterMap(CounterMap) {} - - // Blocks and lambdas are handled as separate functions, so we need not - // traverse them in the parent context. - bool TraverseBlockExpr(BlockExpr *BE) { return true; } - bool TraverseLambdaBody(LambdaExpr *LE) { return true; } - bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } - - bool VisitDecl(const Decl *D) { - switch (D->getKind()) { - default: - break; - case Decl::Function: - case Decl::CXXMethod: - case Decl::CXXConstructor: - case Decl::CXXDestructor: - case Decl::CXXConversion: - case Decl::ObjCMethod: - case Decl::Block: - case Decl::Captured: - CounterMap[D->getBody()] = NextCounter++; - break; - } - return true; +/// A RecursiveASTVisitor that fills a map of statements to PGO counters. +struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> { + /// The next counter value to assign. + unsigned NextCounter; + /// The function hash. + PGOHash Hash; + /// The map of statements to counters. + llvm::DenseMap<const Stmt *, unsigned> &CounterMap; + + MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap) + : NextCounter(0), CounterMap(CounterMap) {} + + // Blocks and lambdas are handled as separate functions, so we need not + // traverse them in the parent context. + bool TraverseBlockExpr(BlockExpr *BE) { return true; } + bool TraverseLambdaBody(LambdaExpr *LE) { return true; } + bool TraverseCapturedStmt(CapturedStmt *CS) { return true; } + + bool VisitDecl(const Decl *D) { + switch (D->getKind()) { + default: + break; + case Decl::Function: + case Decl::CXXMethod: + case Decl::CXXConstructor: + case Decl::CXXDestructor: + case Decl::CXXConversion: + case Decl::ObjCMethod: + case Decl::Block: + case Decl::Captured: + CounterMap[D->getBody()] = NextCounter++; + break; } + return true; + } - bool VisitStmt(const Stmt *S) { - auto Type = getHashType(S); - if (Type == PGOHash::None) - return true; - - CounterMap[S] = NextCounter++; - Hash.combine(Type); + bool VisitStmt(const Stmt *S) { + auto Type = getHashType(S); + if (Type == PGOHash::None) return true; + + CounterMap[S] = NextCounter++; + Hash.combine(Type); + return true; + } + PGOHash::HashType getHashType(const Stmt *S) { + switch (S->getStmtClass()) { + default: + break; + case Stmt::LabelStmtClass: + return PGOHash::LabelStmt; + case Stmt::WhileStmtClass: + return PGOHash::WhileStmt; + case Stmt::DoStmtClass: + return PGOHash::DoStmt; + case Stmt::ForStmtClass: + return PGOHash::ForStmt; + case Stmt::CXXForRangeStmtClass: + return PGOHash::CXXForRangeStmt; + case Stmt::ObjCForCollectionStmtClass: + return PGOHash::ObjCForCollectionStmt; + case Stmt::SwitchStmtClass: + return PGOHash::SwitchStmt; + case Stmt::CaseStmtClass: + return PGOHash::CaseStmt; + case Stmt::DefaultStmtClass: + return PGOHash::DefaultStmt; + case Stmt::IfStmtClass: + return PGOHash::IfStmt; + case Stmt::CXXTryStmtClass: + return PGOHash::CXXTryStmt; + case Stmt::CXXCatchStmtClass: + return PGOHash::CXXCatchStmt; + case Stmt::ConditionalOperatorClass: + return PGOHash::ConditionalOperator; + case Stmt::BinaryConditionalOperatorClass: + return PGOHash::BinaryConditionalOperator; + case Stmt::BinaryOperatorClass: { + const BinaryOperator *BO = cast<BinaryOperator>(S); + if (BO->getOpcode() == BO_LAnd) + return PGOHash::BinaryOperatorLAnd; + if (BO->getOpcode() == BO_LOr) + return PGOHash::BinaryOperatorLOr; + break; } - PGOHash::HashType getHashType(const Stmt *S) { - switch (S->getStmtClass()) { - default: - break; - case Stmt::LabelStmtClass: - return PGOHash::LabelStmt; - case Stmt::WhileStmtClass: - return PGOHash::WhileStmt; - case Stmt::DoStmtClass: - return PGOHash::DoStmt; - case Stmt::ForStmtClass: - return PGOHash::ForStmt; - case Stmt::CXXForRangeStmtClass: - return PGOHash::CXXForRangeStmt; - case Stmt::ObjCForCollectionStmtClass: - return PGOHash::ObjCForCollectionStmt; - case Stmt::SwitchStmtClass: - return PGOHash::SwitchStmt; - case Stmt::CaseStmtClass: - return PGOHash::CaseStmt; - case Stmt::DefaultStmtClass: - return PGOHash::DefaultStmt; - case Stmt::IfStmtClass: - return PGOHash::IfStmt; - case Stmt::CXXTryStmtClass: - return PGOHash::CXXTryStmt; - case Stmt::CXXCatchStmtClass: - return PGOHash::CXXCatchStmt; - case Stmt::ConditionalOperatorClass: - return PGOHash::ConditionalOperator; - case Stmt::BinaryConditionalOperatorClass: - return PGOHash::BinaryConditionalOperator; - case Stmt::BinaryOperatorClass: { - const BinaryOperator *BO = cast<BinaryOperator>(S); - if (BO->getOpcode() == BO_LAnd) - return PGOHash::BinaryOperatorLAnd; - if (BO->getOpcode() == BO_LOr) - return PGOHash::BinaryOperatorLOr; - break; - } - } - return PGOHash::None; } + return PGOHash::None; + } +}; + +/// A StmtVisitor that propagates the raw counts through the AST and +/// records the count at statements where the value may change. +struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { + /// PGO state. + CodeGenPGO &PGO; + + /// A flag that is set when the current count should be recorded on the + /// next statement, such as at the exit of a loop. + bool RecordNextStmtCount; + + /// The count at the current location in the traversal. + uint64_t CurrentCount; + + /// The map of statements to count values. + llvm::DenseMap<const Stmt *, uint64_t> &CountMap; + + /// BreakContinueStack - Keep counts of breaks and continues inside loops. + struct BreakContinue { + uint64_t BreakCount; + uint64_t ContinueCount; + BreakContinue() : BreakCount(0), ContinueCount(0) {} }; + SmallVector<BreakContinue, 8> BreakContinueStack; - /// A StmtVisitor that propagates the raw counts through the AST and - /// records the count at statements where the value may change. - struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> { - /// PGO state. - CodeGenPGO &PGO; - - /// A flag that is set when the current count should be recorded on the - /// next statement, such as at the exit of a loop. - bool RecordNextStmtCount; - - /// The map of statements to count values. - llvm::DenseMap<const Stmt *, uint64_t> &CountMap; - - /// BreakContinueStack - Keep counts of breaks and continues inside loops. - struct BreakContinue { - uint64_t BreakCount; - uint64_t ContinueCount; - BreakContinue() : BreakCount(0), ContinueCount(0) {} - }; - SmallVector<BreakContinue, 8> BreakContinueStack; - - ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, - CodeGenPGO &PGO) - : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} - - void RecordStmtCount(const Stmt *S) { - if (RecordNextStmtCount) { - CountMap[S] = PGO.getCurrentRegionCount(); - RecordNextStmtCount = false; - } - } + ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap, + CodeGenPGO &PGO) + : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {} - void VisitStmt(const Stmt *S) { - RecordStmtCount(S); - for (Stmt::const_child_range I = S->children(); I; ++I) { - if (*I) - this->Visit(*I); - } + void RecordStmtCount(const Stmt *S) { + if (RecordNextStmtCount) { + CountMap[S] = CurrentCount; + RecordNextStmtCount = false; } + } - void VisitFunctionDecl(const FunctionDecl *D) { - // Counter tracks entry to the function body. - RegionCounter Cnt(PGO, D->getBody()); - Cnt.beginRegion(); - CountMap[D->getBody()] = PGO.getCurrentRegionCount(); - Visit(D->getBody()); - } + /// Set and return the current count. + uint64_t setCount(uint64_t Count) { + CurrentCount = Count; + return Count; + } - // Skip lambda expressions. We visit these as FunctionDecls when we're - // generating them and aren't interested in the body when generating a - // parent context. - void VisitLambdaExpr(const LambdaExpr *LE) {} - - void VisitCapturedDecl(const CapturedDecl *D) { - // Counter tracks entry to the capture body. - RegionCounter Cnt(PGO, D->getBody()); - Cnt.beginRegion(); - CountMap[D->getBody()] = PGO.getCurrentRegionCount(); - Visit(D->getBody()); + void VisitStmt(const Stmt *S) { + RecordStmtCount(S); + for (Stmt::const_child_range I = S->children(); I; ++I) { + if (*I) + this->Visit(*I); } + } - void VisitObjCMethodDecl(const ObjCMethodDecl *D) { - // Counter tracks entry to the method body. - RegionCounter Cnt(PGO, D->getBody()); - Cnt.beginRegion(); - CountMap[D->getBody()] = PGO.getCurrentRegionCount(); - Visit(D->getBody()); - } + void VisitFunctionDecl(const FunctionDecl *D) { + // Counter tracks entry to the function body. + uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); + CountMap[D->getBody()] = BodyCount; + Visit(D->getBody()); + } - void VisitBlockDecl(const BlockDecl *D) { - // Counter tracks entry to the block body. - RegionCounter Cnt(PGO, D->getBody()); - Cnt.beginRegion(); - CountMap[D->getBody()] = PGO.getCurrentRegionCount(); - Visit(D->getBody()); - } + // Skip lambda expressions. We visit these as FunctionDecls when we're + // generating them and aren't interested in the body when generating a + // parent context. + void VisitLambdaExpr(const LambdaExpr *LE) {} - void VisitReturnStmt(const ReturnStmt *S) { - RecordStmtCount(S); - if (S->getRetValue()) - Visit(S->getRetValue()); - PGO.setCurrentRegionUnreachable(); - RecordNextStmtCount = true; - } + void VisitCapturedDecl(const CapturedDecl *D) { + // Counter tracks entry to the capture body. + uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); + CountMap[D->getBody()] = BodyCount; + Visit(D->getBody()); + } - void VisitGotoStmt(const GotoStmt *S) { - RecordStmtCount(S); - PGO.setCurrentRegionUnreachable(); - RecordNextStmtCount = true; - } + void VisitObjCMethodDecl(const ObjCMethodDecl *D) { + // Counter tracks entry to the method body. + uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); + CountMap[D->getBody()] = BodyCount; + Visit(D->getBody()); + } - void VisitLabelStmt(const LabelStmt *S) { - RecordNextStmtCount = false; - // Counter tracks the block following the label. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(); - CountMap[S] = PGO.getCurrentRegionCount(); - Visit(S->getSubStmt()); - } + void VisitBlockDecl(const BlockDecl *D) { + // Counter tracks entry to the block body. + uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody())); + CountMap[D->getBody()] = BodyCount; + Visit(D->getBody()); + } - void VisitBreakStmt(const BreakStmt *S) { - RecordStmtCount(S); - assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); - BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount(); - PGO.setCurrentRegionUnreachable(); - RecordNextStmtCount = true; - } + void VisitReturnStmt(const ReturnStmt *S) { + RecordStmtCount(S); + if (S->getRetValue()) + Visit(S->getRetValue()); + CurrentCount = 0; + RecordNextStmtCount = true; + } - void VisitContinueStmt(const ContinueStmt *S) { - RecordStmtCount(S); - assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); - BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount(); - PGO.setCurrentRegionUnreachable(); - RecordNextStmtCount = true; - } + void VisitCXXThrowExpr(const CXXThrowExpr *E) { + RecordStmtCount(E); + if (E->getSubExpr()) + Visit(E->getSubExpr()); + CurrentCount = 0; + RecordNextStmtCount = true; + } - void VisitWhileStmt(const WhileStmt *S) { - RecordStmtCount(S); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); - BreakContinueStack.push_back(BreakContinue()); - // Visit the body region first so the break/continue adjustments can be - // included when visiting the condition. - Cnt.beginRegion(); - CountMap[S->getBody()] = PGO.getCurrentRegionCount(); - Visit(S->getBody()); - Cnt.adjustForControlFlow(); - - // ...then go back and propagate counts through the condition. The count - // at the start of the condition is the sum of the incoming edges, - // the backedge from the end of the loop body, and the edges from - // continue statements. - BreakContinue BC = BreakContinueStack.pop_back_val(); - Cnt.setCurrentRegionCount(Cnt.getParentCount() + - Cnt.getAdjustedCount() + BC.ContinueCount); - CountMap[S->getCond()] = PGO.getCurrentRegionCount(); - Visit(S->getCond()); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); - RecordNextStmtCount = true; - } + void VisitGotoStmt(const GotoStmt *S) { + RecordStmtCount(S); + CurrentCount = 0; + RecordNextStmtCount = true; + } - void VisitDoStmt(const DoStmt *S) { - RecordStmtCount(S); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); - BreakContinueStack.push_back(BreakContinue()); - Cnt.beginRegion(/*AddIncomingFallThrough=*/true); - CountMap[S->getBody()] = PGO.getCurrentRegionCount(); - Visit(S->getBody()); - Cnt.adjustForControlFlow(); - - BreakContinue BC = BreakContinueStack.pop_back_val(); - // The count at the start of the condition is equal to the count at the - // end of the body. The adjusted count does not include either the - // fall-through count coming into the loop or the continue count, so add - // both of those separately. This is coincidentally the same equation as - // with while loops but for different reasons. - Cnt.setCurrentRegionCount(Cnt.getParentCount() + - Cnt.getAdjustedCount() + BC.ContinueCount); - CountMap[S->getCond()] = PGO.getCurrentRegionCount(); - Visit(S->getCond()); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); - RecordNextStmtCount = true; - } + void VisitLabelStmt(const LabelStmt *S) { + RecordNextStmtCount = false; + // Counter tracks the block following the label. + uint64_t BlockCount = setCount(PGO.getRegionCount(S)); + CountMap[S] = BlockCount; + Visit(S->getSubStmt()); + } - void VisitForStmt(const ForStmt *S) { - RecordStmtCount(S); - if (S->getInit()) - Visit(S->getInit()); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); - BreakContinueStack.push_back(BreakContinue()); - // Visit the body region first. (This is basically the same as a while - // loop; see further comments in VisitWhileStmt.) - Cnt.beginRegion(); - CountMap[S->getBody()] = PGO.getCurrentRegionCount(); - Visit(S->getBody()); - Cnt.adjustForControlFlow(); - - // The increment is essentially part of the body but it needs to include - // the count for all the continue statements. - if (S->getInc()) { - Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + - BreakContinueStack.back().ContinueCount); - CountMap[S->getInc()] = PGO.getCurrentRegionCount(); - Visit(S->getInc()); - Cnt.adjustForControlFlow(); - } - - BreakContinue BC = BreakContinueStack.pop_back_val(); - - // ...then go back and propagate counts through the condition. - if (S->getCond()) { - Cnt.setCurrentRegionCount(Cnt.getParentCount() + - Cnt.getAdjustedCount() + - BC.ContinueCount); - CountMap[S->getCond()] = PGO.getCurrentRegionCount(); - Visit(S->getCond()); - Cnt.adjustForControlFlow(); - } - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); - RecordNextStmtCount = true; - } + void VisitBreakStmt(const BreakStmt *S) { + RecordStmtCount(S); + assert(!BreakContinueStack.empty() && "break not in a loop or switch!"); + BreakContinueStack.back().BreakCount += CurrentCount; + CurrentCount = 0; + RecordNextStmtCount = true; + } - void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { - RecordStmtCount(S); - Visit(S->getRangeStmt()); - Visit(S->getBeginEndStmt()); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); - BreakContinueStack.push_back(BreakContinue()); - // Visit the body region first. (This is basically the same as a while - // loop; see further comments in VisitWhileStmt.) - Cnt.beginRegion(); - CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount(); - Visit(S->getLoopVarStmt()); - Visit(S->getBody()); - Cnt.adjustForControlFlow(); - - // The increment is essentially part of the body but it needs to include - // the count for all the continue statements. - Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() + - BreakContinueStack.back().ContinueCount); - CountMap[S->getInc()] = PGO.getCurrentRegionCount(); - Visit(S->getInc()); - Cnt.adjustForControlFlow(); + void VisitContinueStmt(const ContinueStmt *S) { + RecordStmtCount(S); + assert(!BreakContinueStack.empty() && "continue stmt not in a loop!"); + BreakContinueStack.back().ContinueCount += CurrentCount; + CurrentCount = 0; + RecordNextStmtCount = true; + } - BreakContinue BC = BreakContinueStack.pop_back_val(); + void VisitWhileStmt(const WhileStmt *S) { + RecordStmtCount(S); + uint64_t ParentCount = CurrentCount; + + BreakContinueStack.push_back(BreakContinue()); + // Visit the body region first so the break/continue adjustments can be + // included when visiting the condition. + uint64_t BodyCount = setCount(PGO.getRegionCount(S)); + CountMap[S->getBody()] = CurrentCount; + Visit(S->getBody()); + uint64_t BackedgeCount = CurrentCount; + + // ...then go back and propagate counts through the condition. The count + // at the start of the condition is the sum of the incoming edges, + // the backedge from the end of the loop body, and the edges from + // continue statements. + BreakContinue BC = BreakContinueStack.pop_back_val(); + uint64_t CondCount = + setCount(ParentCount + BackedgeCount + BC.ContinueCount); + CountMap[S->getCond()] = CondCount; + Visit(S->getCond()); + setCount(BC.BreakCount + CondCount - BodyCount); + RecordNextStmtCount = true; + } - // ...then go back and propagate counts through the condition. - Cnt.setCurrentRegionCount(Cnt.getParentCount() + - Cnt.getAdjustedCount() + - BC.ContinueCount); - CountMap[S->getCond()] = PGO.getCurrentRegionCount(); - Visit(S->getCond()); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); - RecordNextStmtCount = true; - } + void VisitDoStmt(const DoStmt *S) { + RecordStmtCount(S); + uint64_t LoopCount = PGO.getRegionCount(S); + + BreakContinueStack.push_back(BreakContinue()); + // The count doesn't include the fallthrough from the parent scope. Add it. + uint64_t BodyCount = setCount(LoopCount + CurrentCount); + CountMap[S->getBody()] = BodyCount; + Visit(S->getBody()); + uint64_t BackedgeCount = CurrentCount; + + BreakContinue BC = BreakContinueStack.pop_back_val(); + // The count at the start of the condition is equal to the count at the + // end of the body, plus any continues. + uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount); + CountMap[S->getCond()] = CondCount; + Visit(S->getCond()); + setCount(BC.BreakCount + CondCount - LoopCount); + RecordNextStmtCount = true; + } - void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { - RecordStmtCount(S); - Visit(S->getElement()); - // Counter tracks the body of the loop. - RegionCounter Cnt(PGO, S); - BreakContinueStack.push_back(BreakContinue()); - Cnt.beginRegion(); - CountMap[S->getBody()] = PGO.getCurrentRegionCount(); - Visit(S->getBody()); - BreakContinue BC = BreakContinueStack.pop_back_val(); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount); - RecordNextStmtCount = true; + void VisitForStmt(const ForStmt *S) { + RecordStmtCount(S); + if (S->getInit()) + Visit(S->getInit()); + + uint64_t ParentCount = CurrentCount; + + BreakContinueStack.push_back(BreakContinue()); + // Visit the body region first. (This is basically the same as a while + // loop; see further comments in VisitWhileStmt.) + uint64_t BodyCount = setCount(PGO.getRegionCount(S)); + CountMap[S->getBody()] = BodyCount; + Visit(S->getBody()); + uint64_t BackedgeCount = CurrentCount; + BreakContinue BC = BreakContinueStack.pop_back_val(); + + // The increment is essentially part of the body but it needs to include + // the count for all the continue statements. + if (S->getInc()) { + uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); + CountMap[S->getInc()] = IncCount; + Visit(S->getInc()); } - void VisitSwitchStmt(const SwitchStmt *S) { - RecordStmtCount(S); + // ...then go back and propagate counts through the condition. + uint64_t CondCount = + setCount(ParentCount + BackedgeCount + BC.ContinueCount); + if (S->getCond()) { + CountMap[S->getCond()] = CondCount; Visit(S->getCond()); - PGO.setCurrentRegionUnreachable(); - BreakContinueStack.push_back(BreakContinue()); - Visit(S->getBody()); - // If the switch is inside a loop, add the continue counts. - BreakContinue BC = BreakContinueStack.pop_back_val(); - if (!BreakContinueStack.empty()) - BreakContinueStack.back().ContinueCount += BC.ContinueCount; - // Counter tracks the exit block of the switch. - RegionCounter ExitCnt(PGO, S); - ExitCnt.beginRegion(); - RecordNextStmtCount = true; } + setCount(BC.BreakCount + CondCount - BodyCount); + RecordNextStmtCount = true; + } - void VisitCaseStmt(const CaseStmt *S) { - RecordNextStmtCount = false; - // Counter for this particular case. This counts only jumps from the - // switch header and does not include fallthrough from the case before - // this one. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(/*AddIncomingFallThrough=*/true); - CountMap[S] = Cnt.getCount(); - RecordNextStmtCount = true; - Visit(S->getSubStmt()); - } + void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { + RecordStmtCount(S); + Visit(S->getLoopVarStmt()); + Visit(S->getRangeStmt()); + Visit(S->getBeginEndStmt()); + + uint64_t ParentCount = CurrentCount; + BreakContinueStack.push_back(BreakContinue()); + // Visit the body region first. (This is basically the same as a while + // loop; see further comments in VisitWhileStmt.) + uint64_t BodyCount = setCount(PGO.getRegionCount(S)); + CountMap[S->getBody()] = BodyCount; + Visit(S->getBody()); + uint64_t BackedgeCount = CurrentCount; + BreakContinue BC = BreakContinueStack.pop_back_val(); + + // The increment is essentially part of the body but it needs to include + // the count for all the continue statements. + uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount); + CountMap[S->getInc()] = IncCount; + Visit(S->getInc()); + + // ...then go back and propagate counts through the condition. + uint64_t CondCount = + setCount(ParentCount + BackedgeCount + BC.ContinueCount); + CountMap[S->getCond()] = CondCount; + Visit(S->getCond()); + setCount(BC.BreakCount + CondCount - BodyCount); + RecordNextStmtCount = true; + } - void VisitDefaultStmt(const DefaultStmt *S) { - RecordNextStmtCount = false; - // Counter for this default case. This does not include fallthrough from - // the previous case. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(/*AddIncomingFallThrough=*/true); - CountMap[S] = Cnt.getCount(); - RecordNextStmtCount = true; - Visit(S->getSubStmt()); - } + void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { + RecordStmtCount(S); + Visit(S->getElement()); + uint64_t ParentCount = CurrentCount; + BreakContinueStack.push_back(BreakContinue()); + // Counter tracks the body of the loop. + uint64_t BodyCount = setCount(PGO.getRegionCount(S)); + CountMap[S->getBody()] = BodyCount; + Visit(S->getBody()); + uint64_t BackedgeCount = CurrentCount; + BreakContinue BC = BreakContinueStack.pop_back_val(); + + setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount - + BodyCount); + RecordNextStmtCount = true; + } - void VisitIfStmt(const IfStmt *S) { - RecordStmtCount(S); - // Counter tracks the "then" part of an if statement. The count for - // the "else" part, if it exists, will be calculated from this counter. - RegionCounter Cnt(PGO, S); - Visit(S->getCond()); + void VisitSwitchStmt(const SwitchStmt *S) { + RecordStmtCount(S); + Visit(S->getCond()); + CurrentCount = 0; + BreakContinueStack.push_back(BreakContinue()); + Visit(S->getBody()); + // If the switch is inside a loop, add the continue counts. + BreakContinue BC = BreakContinueStack.pop_back_val(); + if (!BreakContinueStack.empty()) + BreakContinueStack.back().ContinueCount += BC.ContinueCount; + // Counter tracks the exit block of the switch. + setCount(PGO.getRegionCount(S)); + RecordNextStmtCount = true; + } - Cnt.beginRegion(); - CountMap[S->getThen()] = PGO.getCurrentRegionCount(); - Visit(S->getThen()); - Cnt.adjustForControlFlow(); - - if (S->getElse()) { - Cnt.beginElseRegion(); - CountMap[S->getElse()] = PGO.getCurrentRegionCount(); - Visit(S->getElse()); - Cnt.adjustForControlFlow(); - } - Cnt.applyAdjustmentsToRegion(0); - RecordNextStmtCount = true; - } + void VisitSwitchCase(const SwitchCase *S) { + RecordNextStmtCount = false; + // Counter for this particular case. This counts only jumps from the + // switch header and does not include fallthrough from the case before + // this one. + uint64_t CaseCount = PGO.getRegionCount(S); + setCount(CurrentCount + CaseCount); + // We need the count without fallthrough in the mapping, so it's more useful + // for branch probabilities. + CountMap[S] = CaseCount; + RecordNextStmtCount = true; + Visit(S->getSubStmt()); + } - void VisitCXXTryStmt(const CXXTryStmt *S) { - RecordStmtCount(S); - Visit(S->getTryBlock()); - for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) - Visit(S->getHandler(I)); - // Counter tracks the continuation block of the try statement. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(); - RecordNextStmtCount = true; - } + void VisitIfStmt(const IfStmt *S) { + RecordStmtCount(S); + uint64_t ParentCount = CurrentCount; + Visit(S->getCond()); + + // Counter tracks the "then" part of an if statement. The count for + // the "else" part, if it exists, will be calculated from this counter. + uint64_t ThenCount = setCount(PGO.getRegionCount(S)); + CountMap[S->getThen()] = ThenCount; + Visit(S->getThen()); + uint64_t OutCount = CurrentCount; + + uint64_t ElseCount = ParentCount - ThenCount; + if (S->getElse()) { + setCount(ElseCount); + CountMap[S->getElse()] = ElseCount; + Visit(S->getElse()); + OutCount += CurrentCount; + } else + OutCount += ElseCount; + setCount(OutCount); + RecordNextStmtCount = true; + } - void VisitCXXCatchStmt(const CXXCatchStmt *S) { - RecordNextStmtCount = false; - // Counter tracks the catch statement's handler block. - RegionCounter Cnt(PGO, S); - Cnt.beginRegion(); - CountMap[S] = PGO.getCurrentRegionCount(); - Visit(S->getHandlerBlock()); - } + void VisitCXXTryStmt(const CXXTryStmt *S) { + RecordStmtCount(S); + Visit(S->getTryBlock()); + for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) + Visit(S->getHandler(I)); + // Counter tracks the continuation block of the try statement. + setCount(PGO.getRegionCount(S)); + RecordNextStmtCount = true; + } - void VisitAbstractConditionalOperator( - const AbstractConditionalOperator *E) { - RecordStmtCount(E); - // Counter tracks the "true" part of a conditional operator. The - // count in the "false" part will be calculated from this counter. - RegionCounter Cnt(PGO, E); - Visit(E->getCond()); - - Cnt.beginRegion(); - CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount(); - Visit(E->getTrueExpr()); - Cnt.adjustForControlFlow(); - - Cnt.beginElseRegion(); - CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount(); - Visit(E->getFalseExpr()); - Cnt.adjustForControlFlow(); - - Cnt.applyAdjustmentsToRegion(0); - RecordNextStmtCount = true; - } + void VisitCXXCatchStmt(const CXXCatchStmt *S) { + RecordNextStmtCount = false; + // Counter tracks the catch statement's handler block. + uint64_t CatchCount = setCount(PGO.getRegionCount(S)); + CountMap[S] = CatchCount; + Visit(S->getHandlerBlock()); + } - void VisitBinLAnd(const BinaryOperator *E) { - RecordStmtCount(E); - // Counter tracks the right hand side of a logical and operator. - RegionCounter Cnt(PGO, E); - Visit(E->getLHS()); - Cnt.beginRegion(); - CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); - Visit(E->getRHS()); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(0); - RecordNextStmtCount = true; - } + void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) { + RecordStmtCount(E); + uint64_t ParentCount = CurrentCount; + Visit(E->getCond()); + + // Counter tracks the "true" part of a conditional operator. The + // count in the "false" part will be calculated from this counter. + uint64_t TrueCount = setCount(PGO.getRegionCount(E)); + CountMap[E->getTrueExpr()] = TrueCount; + Visit(E->getTrueExpr()); + uint64_t OutCount = CurrentCount; + + uint64_t FalseCount = setCount(ParentCount - TrueCount); + CountMap[E->getFalseExpr()] = FalseCount; + Visit(E->getFalseExpr()); + OutCount += CurrentCount; + + setCount(OutCount); + RecordNextStmtCount = true; + } - void VisitBinLOr(const BinaryOperator *E) { - RecordStmtCount(E); - // Counter tracks the right hand side of a logical or operator. - RegionCounter Cnt(PGO, E); - Visit(E->getLHS()); - Cnt.beginRegion(); - CountMap[E->getRHS()] = PGO.getCurrentRegionCount(); - Visit(E->getRHS()); - Cnt.adjustForControlFlow(); - Cnt.applyAdjustmentsToRegion(0); - RecordNextStmtCount = true; - } - }; + void VisitBinLAnd(const BinaryOperator *E) { + RecordStmtCount(E); + uint64_t ParentCount = CurrentCount; + Visit(E->getLHS()); + // Counter tracks the right hand side of a logical and operator. + uint64_t RHSCount = setCount(PGO.getRegionCount(E)); + CountMap[E->getRHS()] = RHSCount; + Visit(E->getRHS()); + setCount(ParentCount + RHSCount - CurrentCount); + RecordNextStmtCount = true; + } + + void VisitBinLOr(const BinaryOperator *E) { + RecordStmtCount(E); + uint64_t ParentCount = CurrentCount; + Visit(E->getLHS()); + // Counter tracks the right hand side of a logical or operator. + uint64_t RHSCount = setCount(PGO.getRegionCount(E)); + CountMap[E->getRHS()] = RHSCount; + Visit(E->getRHS()); + setCount(ParentCount + RHSCount - CurrentCount); + RecordNextStmtCount = true; + } +}; } void PGOHash::combine(HashType Type) { @@ -728,12 +719,10 @@ void CodeGenPGO::emitCounterRegionMapping(const Decl *D) { } void -CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef FuncName, +CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name, llvm::GlobalValue::LinkageTypes Linkage) { if (SkipCoverageMapping) return; - setFuncName(FuncName, Linkage); - // Don't map the functions inside the system headers auto Loc = D->getBody()->getLocStart(); if (CGM.getContext().getSourceManager().isInSystemHeader(Loc)) @@ -750,6 +739,7 @@ CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef FuncName, if (CoverageMapping.empty()) return; + setFuncName(Name, Linkage); CGM.getCoverageMapping()->addFunctionMappingRecord( FuncNameVar, FuncName, FunctionHash, CoverageMapping); } @@ -785,17 +775,19 @@ CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader, Fn->addFnAttr(llvm::Attribute::Cold); } -void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) { +void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) { if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap) return; if (!Builder.GetInsertPoint()) return; + + unsigned Counter = (*RegionCounterMap)[S]; auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext()); - Builder.CreateCall4(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment), - llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), + Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment), + {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), Builder.getInt64(FunctionHash), Builder.getInt32(NumRegionCounters), - Builder.getInt32(Counter)); + Builder.getInt32(Counter)}); } void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, @@ -839,8 +831,8 @@ static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { return Scaled; } -llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount, - uint64_t FalseCount) { +llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, + uint64_t FalseCount) { // Check for empty weights. if (!TrueCount && !FalseCount) return nullptr; @@ -853,7 +845,8 @@ llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount, scaleBranchWeight(FalseCount, Scale)); } -llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) { +llvm::MDNode * +CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) { // We need at least two elements to create meaningful weights. if (Weights.size() < 2) return nullptr; @@ -875,17 +868,14 @@ llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) { return MDHelper.createBranchWeights(ScaledWeights); } -llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond, - RegionCounter &Cnt) { - if (!haveRegionCounts()) +llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, + uint64_t LoopCount) { + if (!PGO.haveRegionCounts()) return nullptr; - uint64_t LoopCount = Cnt.getCount(); - uint64_t CondCount = 0; - bool Found = getStmtCount(Cond, CondCount); - assert(Found && "missing expected loop condition count"); - (void)Found; - if (CondCount == 0) + Optional<uint64_t> CondCount = PGO.getStmtCount(Cond); + assert(CondCount.hasValue() && "missing expected loop condition count"); + if (*CondCount == 0) return nullptr; - return createBranchWeights(LoopCount, - std::max(CondCount, LoopCount) - LoopCount); + return createProfileWeights(LoopCount, + std::max(*CondCount, LoopCount) - LoopCount); } |