diff options
Diffstat (limited to 'llvm/lib/Transforms/IPO/OpenMPOpt.cpp')
-rw-r--r-- | llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 141 |
1 files changed, 85 insertions, 56 deletions
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index f289e3ecc979..68f33410c602 100644 --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -21,6 +21,7 @@ #include "llvm/ADT/EnumeratedArray.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/CallGraph.h" @@ -153,14 +154,6 @@ static constexpr auto TAG = "[" DEBUG_TYPE "]"; namespace { -enum class AddressSpace : unsigned { - Generic = 0, - Global = 1, - Shared = 3, - Constant = 4, - Local = 5, -}; - struct AAHeapToShared; struct AAICVTracker; @@ -170,7 +163,7 @@ struct AAICVTracker; struct OMPInformationCache : public InformationCache { OMPInformationCache(Module &M, AnalysisGetter &AG, BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC, - SmallPtrSetImpl<Kernel> &Kernels) + KernelSet &Kernels) : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M), Kernels(Kernels) { @@ -424,6 +417,12 @@ struct OMPInformationCache : public InformationCache { recollectUsesForFunction(static_cast<RuntimeFunction>(Idx)); } + // Helper function to inherit the calling convention of the function callee. + void setCallingConvention(FunctionCallee Callee, CallInst *CI) { + if (Function *Fn = dyn_cast<Function>(Callee.getCallee())) + CI->setCallingConv(Fn->getCallingConv()); + } + /// Helper to initialize all runtime function information for those defined /// in OpenMPKinds.def. void initializeRuntimeFunctions() { @@ -485,7 +484,7 @@ struct OMPInformationCache : public InformationCache { } /// Collection of known kernels (\see Kernel) in the module. - SmallPtrSetImpl<Kernel> &Kernels; + KernelSet &Kernels; /// Collection of known OpenMP runtime functions.. DenseSet<const Function *> RTLFunctions; @@ -1013,7 +1012,8 @@ private: // into a single parallel region is contained in a single basic block // without any other instructions. We use the OpenMPIRBuilder to outline // that block and call the resulting function via __kmpc_fork_call. - auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) { + auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs, + BasicBlock *BB) { // TODO: Change the interface to allow single CIs expanded, e.g, to // include an outer loop. assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs"); @@ -1075,8 +1075,7 @@ private: BranchInst::Create(AfterBB, AfterIP.getBlock()); // Perform the actual outlining. - OMPInfoCache.OMPBuilder.finalize(OriginalFn, - /* AllowExtractorSinking */ true); + OMPInfoCache.OMPBuilder.finalize(OriginalFn); Function *OutlinedFn = MergableCIs.front()->getCaller(); @@ -1538,6 +1537,7 @@ private: CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall); + OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite); RuntimeCall.eraseFromParent(); // Add "wait" runtime call declaration: @@ -1550,7 +1550,9 @@ private: OffloadArray::DeviceIDArgNum), // device_id. Handle // handle to wait on. }; - CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint); + CallInst *WaitCallsite = CallInst::Create( + WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint); + OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite); return true; } @@ -1597,8 +1599,10 @@ private: &F.getEntryBlock(), F.getEntryBlock().begin())); // Create a fallback location if non was found. // TODO: Use the debug locations of the calls instead. - Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(); - Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc); + uint32_t SrcLocStrSize; + Constant *Loc = + OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); + Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize); } return Ident; } @@ -2171,7 +2175,7 @@ struct AAICVTrackerFunction : public AAICVTracker { }; auto CallCheck = [&](Instruction &I) { - Optional<Value *> ReplVal = getValueForCall(A, &I, ICV); + Optional<Value *> ReplVal = getValueForCall(A, I, ICV); if (ReplVal.hasValue() && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second) HasChanged = ChangeStatus::CHANGED; @@ -2197,12 +2201,12 @@ struct AAICVTrackerFunction : public AAICVTracker { return HasChanged; } - /// Hepler to check if \p I is a call and get the value for it if it is + /// Helper to check if \p I is a call and get the value for it if it is /// unique. - Optional<Value *> getValueForCall(Attributor &A, const Instruction *I, + Optional<Value *> getValueForCall(Attributor &A, const Instruction &I, InternalControlVar &ICV) const { - const auto *CB = dyn_cast<CallBase>(I); + const auto *CB = dyn_cast<CallBase>(&I); if (!CB || CB->hasFnAttr("no_openmp") || CB->hasFnAttr("no_openmp_routines")) return None; @@ -2218,8 +2222,8 @@ struct AAICVTrackerFunction : public AAICVTracker { if (CalledFunction == GetterRFI.Declaration) return None; if (CalledFunction == SetterRFI.Declaration) { - if (ICVReplacementValuesMap[ICV].count(I)) - return ICVReplacementValuesMap[ICV].lookup(I); + if (ICVReplacementValuesMap[ICV].count(&I)) + return ICVReplacementValuesMap[ICV].lookup(&I); return nullptr; } @@ -2231,8 +2235,11 @@ struct AAICVTrackerFunction : public AAICVTracker { const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED); - if (ICVTrackingAA.isAssumedTracked()) - return ICVTrackingAA.getUniqueReplacementValue(ICV); + if (ICVTrackingAA.isAssumedTracked()) { + Optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV); + if (!URV || (*URV && AA::isValidAtPosition(**URV, I, OMPInfoCache))) + return URV; + } // If we don't know, assume it changes. return nullptr; @@ -2284,7 +2291,7 @@ struct AAICVTrackerFunction : public AAICVTracker { break; } - Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV); + Optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV); if (!NewReplVal.hasValue()) continue; @@ -2548,7 +2555,7 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { } /// Set of basic blocks that are executed by a single thread. - DenseSet<const BasicBlock *> SingleThreadedBBs; + SmallSetVector<const BasicBlock *, 16> SingleThreadedBBs; /// Total number of basic blocks in this function. long unsigned NumBBs; @@ -2572,7 +2579,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { if (!A.checkForAllCallSites(PredForCallSite, *this, /* RequiresAllCallSites */ true, AllCallSitesKnown)) - SingleThreadedBBs.erase(&F->getEntryBlock()); + SingleThreadedBBs.remove(&F->getEntryBlock()); auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; @@ -2637,7 +2644,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { for (auto *BB : RPOT) { if (!MergePredecessorStates(BB)) - SingleThreadedBBs.erase(BB); + SingleThreadedBBs.remove(BB); } return (NumSingleThreadedBBs == SingleThreadedBBs.size()) @@ -2759,7 +2766,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared { if (FreeCalls.size() != 1) continue; - ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0)); + auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0)); LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB << " with " << AllocSize->getZExtValue() @@ -2772,7 +2779,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared { Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue()); auto *SharedMem = new GlobalVariable( *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage, - UndefValue::get(Int8ArrTy), CB->getName(), nullptr, + UndefValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr, GlobalValue::NotThreadLocal, static_cast<unsigned>(AddressSpace::Shared)); auto *NewBuffer = @@ -2786,7 +2793,10 @@ struct AAHeapToSharedFunction : public AAHeapToShared { }; A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark); - SharedMem->setAlignment(MaybeAlign(32)); + MaybeAlign Alignment = CB->getRetAlign(); + assert(Alignment && + "HeapToShared on allocation without alignment attribute"); + SharedMem->setAlignment(MaybeAlign(Alignment)); A.changeValueAfterManifest(*CB, *NewBuffer); A.deleteAfterManifest(*CB); @@ -2813,7 +2823,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared { if (CallBase *CB = dyn_cast<CallBase>(U)) if (!isa<ConstantInt>(CB->getArgOperand(0)) || !ED.isExecutedByInitialThreadOnly(*CB)) - MallocCalls.erase(CB); + MallocCalls.remove(CB); } findPotentialRemovedFreeCalls(A); @@ -2825,7 +2835,7 @@ struct AAHeapToSharedFunction : public AAHeapToShared { } /// Collection of all malloc calls in a function. - SmallPtrSet<CallBase *, 4> MallocCalls; + SmallSetVector<CallBase *, 4> MallocCalls; /// Collection of potentially removed free calls in a function. SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls; }; @@ -2962,7 +2972,7 @@ struct AAKernelInfoFunction : AAKernelInfo { A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); UsedAssumedInformation = !isAtFixpoint(); auto *FalseVal = - ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0); + ConstantInt::getBool(IRP.getAnchorValue().getContext(), false); return FalseVal; }; @@ -3225,8 +3235,11 @@ struct AAKernelInfoFunction : AAKernelInfo { OpenMPIRBuilder::LocationDescription Loc( InsertPointTy(ParentBB, ParentBB->end()), DL); OMPInfoCache.OMPBuilder.updateToLocation(Loc); - auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc); - Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr); + uint32_t SrcLocStrSize; + auto *SrcLocStr = + OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize); + Value *Ident = + OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize); BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL); // Add check for Tid in RegionCheckTidBB @@ -3237,8 +3250,10 @@ struct AAKernelInfoFunction : AAKernelInfo { FunctionCallee HardwareTidFn = OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( M, OMPRTL___kmpc_get_hardware_thread_id_in_block); - Value *Tid = + CallInst *Tid = OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {}); + Tid->setDebugLoc(DL); + OMPInfoCache.setCallingConvention(HardwareTidFn, Tid); Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid); OMPInfoCache.OMPBuilder.Builder .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB) @@ -3251,14 +3266,18 @@ struct AAKernelInfoFunction : AAKernelInfo { M, OMPRTL___kmpc_barrier_simple_spmd); OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy( RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt())); - OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid}) - ->setDebugLoc(DL); + CallInst *Barrier = + OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid}); + Barrier->setDebugLoc(DL); + OMPInfoCache.setCallingConvention(BarrierFn, Barrier); // Second barrier ensures workers have read broadcast values. - if (HasBroadcastValues) - CallInst::Create(BarrierFn, {Ident, Tid}, "", - RegionBarrierBB->getTerminator()) - ->setDebugLoc(DL); + if (HasBroadcastValues) { + CallInst *Barrier = CallInst::Create(BarrierFn, {Ident, Tid}, "", + RegionBarrierBB->getTerminator()); + Barrier->setDebugLoc(DL); + OMPInfoCache.setCallingConvention(BarrierFn, Barrier); + } }; auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; @@ -3352,17 +3371,17 @@ struct AAKernelInfoFunction : AAKernelInfo { OMP_TGT_EXEC_MODE_SPMD)); A.changeUseAfterManifest( KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), - *ConstantInt::getBool(Ctx, 0)); + *ConstantInt::getBool(Ctx, false)); A.changeUseAfterManifest( KernelDeinitCB->getArgOperandUse(DeinitModeArgNo), *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx), OMP_TGT_EXEC_MODE_SPMD)); A.changeUseAfterManifest( KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo), - *ConstantInt::getBool(Ctx, 0)); + *ConstantInt::getBool(Ctx, false)); A.changeUseAfterManifest( KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo), - *ConstantInt::getBool(Ctx, 0)); + *ConstantInt::getBool(Ctx, false)); ++NumOpenMPTargetRegionKernelsSPMD; @@ -3403,7 +3422,7 @@ struct AAKernelInfoFunction : AAKernelInfo { // If not SPMD mode, indicate we use a custom state machine now. auto &Ctx = getAnchorValue().getContext(); - auto *FalseVal = ConstantInt::getBool(Ctx, 0); + auto *FalseVal = ConstantInt::getBool(Ctx, false); A.changeUseAfterManifest( KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal); @@ -3528,10 +3547,12 @@ struct AAKernelInfoFunction : AAKernelInfo { FunctionCallee WarpSizeFn = OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( M, OMPRTL___kmpc_get_warp_size); - Instruction *BlockHwSize = + CallInst *BlockHwSize = CallInst::Create(BlockHwSizeFn, "block.hw_size", InitBB); + OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize); BlockHwSize->setDebugLoc(DLoc); - Instruction *WarpSize = CallInst::Create(WarpSizeFn, "warp.size", InitBB); + CallInst *WarpSize = CallInst::Create(WarpSizeFn, "warp.size", InitBB); + OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize); WarpSize->setDebugLoc(DLoc); Instruction *BlockSize = BinaryOperator::CreateSub(BlockHwSize, WarpSize, "block.size", InitBB); @@ -3571,8 +3592,10 @@ struct AAKernelInfoFunction : AAKernelInfo { FunctionCallee BarrierFn = OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( M, OMPRTL___kmpc_barrier_simple_generic); - CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB) - ->setDebugLoc(DLoc); + CallInst *Barrier = + CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB); + OMPInfoCache.setCallingConvention(BarrierFn, Barrier); + Barrier->setDebugLoc(DLoc); if (WorkFnAI->getType()->getPointerAddressSpace() != (unsigned int)AddressSpace::Generic) { @@ -3588,8 +3611,9 @@ struct AAKernelInfoFunction : AAKernelInfo { FunctionCallee KernelParallelFn = OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( M, OMPRTL___kmpc_kernel_parallel); - Instruction *IsActiveWorker = CallInst::Create( + CallInst *IsActiveWorker = CallInst::Create( KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB); + OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker); IsActiveWorker->setDebugLoc(DLoc); Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn", StateMachineBeginBB); @@ -3669,10 +3693,13 @@ struct AAKernelInfoFunction : AAKernelInfo { StateMachineIfCascadeCurrentBB) ->setDebugLoc(DLoc); - CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( - M, OMPRTL___kmpc_kernel_end_parallel), - {}, "", StateMachineEndParallelBB) - ->setDebugLoc(DLoc); + FunctionCallee EndParallelFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_kernel_end_parallel); + CallInst *EndParallel = + CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB); + OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel); + EndParallel->setDebugLoc(DLoc); BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB) ->setDebugLoc(DLoc); @@ -4508,6 +4535,8 @@ void OpenMPOpt::registerAAs(bool IsModulePass) { bool UsedAssumedInformation = false; A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr, UsedAssumedInformation); + } else if (auto *SI = dyn_cast<StoreInst>(&I)) { + A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI)); } } } |