diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/FunctionComparator.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/FunctionComparator.cpp | 107 |
1 files changed, 64 insertions, 43 deletions
diff --git a/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/llvm/lib/Transforms/Utils/FunctionComparator.cpp index a9b28754c8e9..101cb232d8ae 100644 --- a/llvm/lib/Transforms/Utils/FunctionComparator.cpp +++ b/llvm/lib/Transforms/Utils/FunctionComparator.cpp @@ -20,7 +20,6 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -52,22 +51,28 @@ using namespace llvm; #define DEBUG_TYPE "functioncomparator" int FunctionComparator::cmpNumbers(uint64_t L, uint64_t R) const { - if (L < R) return -1; - if (L > R) return 1; + if (L < R) + return -1; + if (L > R) + return 1; return 0; } int FunctionComparator::cmpOrderings(AtomicOrdering L, AtomicOrdering R) const { - if ((int)L < (int)R) return -1; - if ((int)L > (int)R) return 1; + if ((int)L < (int)R) + return -1; + if ((int)L > (int)R) + return 1; return 0; } int FunctionComparator::cmpAPInts(const APInt &L, const APInt &R) const { if (int Res = cmpNumbers(L.getBitWidth(), R.getBitWidth())) return Res; - if (L.ugt(R)) return 1; - if (R.ugt(L)) return -1; + if (L.ugt(R)) + return 1; + if (R.ugt(L)) + return -1; return 0; } @@ -166,21 +171,17 @@ int FunctionComparator::cmpRangeMetadata(const MDNode *L, return 0; } -int FunctionComparator::cmpOperandBundlesSchema(const Instruction *L, - const Instruction *R) const { - ImmutableCallSite LCS(L); - ImmutableCallSite RCS(R); - - assert(LCS && RCS && "Must be calls or invokes!"); - assert(LCS.isCall() == RCS.isCall() && "Can't compare otherwise!"); +int FunctionComparator::cmpOperandBundlesSchema(const CallBase &LCS, + const CallBase &RCS) const { + assert(LCS.getOpcode() == RCS.getOpcode() && "Can't compare otherwise!"); if (int Res = cmpNumbers(LCS.getNumOperandBundles(), RCS.getNumOperandBundles())) return Res; - for (unsigned i = 0, e = LCS.getNumOperandBundles(); i != e; ++i) { - auto OBL = LCS.getOperandBundleAt(i); - auto OBR = RCS.getOperandBundleAt(i); + for (unsigned I = 0, E = LCS.getNumOperandBundles(); I != E; ++I) { + auto OBL = LCS.getOperandBundleAt(I); + auto OBR = RCS.getOperandBundleAt(I); if (int Res = OBL.getTagName().compare(OBR.getTagName())) return Res; @@ -227,9 +228,9 @@ int FunctionComparator::cmpConstants(const Constant *L, unsigned TyRWidth = 0; if (auto *VecTyL = dyn_cast<VectorType>(TyL)) - TyLWidth = VecTyL->getBitWidth(); + TyLWidth = VecTyL->getPrimitiveSizeInBits().getFixedSize(); if (auto *VecTyR = dyn_cast<VectorType>(TyR)) - TyRWidth = VecTyR->getBitWidth(); + TyRWidth = VecTyR->getPrimitiveSizeInBits().getFixedSize(); if (TyLWidth != TyRWidth) return cmpNumbers(TyLWidth, TyRWidth); @@ -328,8 +329,8 @@ int FunctionComparator::cmpConstants(const Constant *L, case Value::ConstantVectorVal: { const ConstantVector *LV = cast<ConstantVector>(L); const ConstantVector *RV = cast<ConstantVector>(R); - unsigned NumElementsL = cast<VectorType>(TyL)->getNumElements(); - unsigned NumElementsR = cast<VectorType>(TyR)->getNumElements(); + unsigned NumElementsL = cast<FixedVectorType>(TyL)->getNumElements(); + unsigned NumElementsR = cast<FixedVectorType>(TyR)->getNumElements(); if (int Res = cmpNumbers(NumElementsL, NumElementsR)) return Res; for (uint64_t i = 0; i < NumElementsL; ++i) { @@ -361,12 +362,12 @@ int FunctionComparator::cmpConstants(const Constant *L, if (LBA->getFunction() == RBA->getFunction()) { // They are BBs in the same function. Order by which comes first in the // BB order of the function. This order is deterministic. - Function* F = LBA->getFunction(); + Function *F = LBA->getFunction(); BasicBlock *LBB = LBA->getBasicBlock(); BasicBlock *RBB = RBA->getBasicBlock(); if (LBB == RBB) return 0; - for(BasicBlock &BB : F->getBasicBlockList()) { + for (BasicBlock &BB : F->getBasicBlockList()) { if (&BB == LBB) { assert(&BB != RBB); return -1; @@ -476,14 +477,25 @@ int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const { return 0; } - case Type::ArrayTyID: - case Type::VectorTyID: { - auto *STyL = cast<SequentialType>(TyL); - auto *STyR = cast<SequentialType>(TyR); + case Type::ArrayTyID: { + auto *STyL = cast<ArrayType>(TyL); + auto *STyR = cast<ArrayType>(TyR); if (STyL->getNumElements() != STyR->getNumElements()) return cmpNumbers(STyL->getNumElements(), STyR->getNumElements()); return cmpTypes(STyL->getElementType(), STyR->getElementType()); } + case Type::FixedVectorTyID: + case Type::ScalableVectorTyID: { + auto *STyL = cast<VectorType>(TyL); + auto *STyR = cast<VectorType>(TyR); + if (STyL->getElementCount().Scalable != STyR->getElementCount().Scalable) + return cmpNumbers(STyL->getElementCount().Scalable, + STyR->getElementCount().Scalable); + if (STyL->getElementCount().Min != STyR->getElementCount().Min) + return cmpNumbers(STyL->getElementCount().Min, + STyR->getElementCount().Min); + return cmpTypes(STyL->getElementType(), STyR->getElementType()); + } } } @@ -551,7 +563,8 @@ int FunctionComparator::cmpOperations(const Instruction *L, if (int Res = cmpNumbers(LI->getSyncScopeID(), cast<LoadInst>(R)->getSyncScopeID())) return Res; - return cmpRangeMetadata(LI->getMetadata(LLVMContext::MD_range), + return cmpRangeMetadata( + LI->getMetadata(LLVMContext::MD_range), cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range)); } if (const StoreInst *SI = dyn_cast<StoreInst>(L)) { @@ -569,13 +582,13 @@ int FunctionComparator::cmpOperations(const Instruction *L, } if (const CmpInst *CI = dyn_cast<CmpInst>(L)) return cmpNumbers(CI->getPredicate(), cast<CmpInst>(R)->getPredicate()); - if (auto CSL = CallSite(const_cast<Instruction *>(L))) { - auto CSR = CallSite(const_cast<Instruction *>(R)); - if (int Res = cmpNumbers(CSL.getCallingConv(), CSR.getCallingConv())) + if (auto *CBL = dyn_cast<CallBase>(L)) { + auto *CBR = cast<CallBase>(R); + if (int Res = cmpNumbers(CBL->getCallingConv(), CBR->getCallingConv())) return Res; - if (int Res = cmpAttrs(CSL.getAttributes(), CSR.getAttributes())) + if (int Res = cmpAttrs(CBL->getAttributes(), CBR->getAttributes())) return Res; - if (int Res = cmpOperandBundlesSchema(L, R)) + if (int Res = cmpOperandBundlesSchema(*CBL, *CBR)) return Res; if (const CallInst *CI = dyn_cast<CallInst>(L)) if (int Res = cmpNumbers(CI->getTailCallKind(), @@ -616,8 +629,8 @@ int FunctionComparator::cmpOperations(const Instruction *L, if (int Res = cmpNumbers(CXI->isVolatile(), cast<AtomicCmpXchgInst>(R)->isVolatile())) return Res; - if (int Res = cmpNumbers(CXI->isWeak(), - cast<AtomicCmpXchgInst>(R)->isWeak())) + if (int Res = + cmpNumbers(CXI->isWeak(), cast<AtomicCmpXchgInst>(R)->isWeak())) return Res; if (int Res = cmpOrderings(CXI->getSuccessOrdering(), @@ -638,11 +651,21 @@ int FunctionComparator::cmpOperations(const Instruction *L, cast<AtomicRMWInst>(R)->isVolatile())) return Res; if (int Res = cmpOrderings(RMWI->getOrdering(), - cast<AtomicRMWInst>(R)->getOrdering())) + cast<AtomicRMWInst>(R)->getOrdering())) return Res; return cmpNumbers(RMWI->getSyncScopeID(), cast<AtomicRMWInst>(R)->getSyncScopeID()); } + if (const ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(L)) { + ArrayRef<int> LMask = SVI->getShuffleMask(); + ArrayRef<int> RMask = cast<ShuffleVectorInst>(R)->getShuffleMask(); + if (int Res = cmpNumbers(LMask.size(), RMask.size())) + return Res; + for (size_t i = 0, e = LMask.size(); i != e; ++i) { + if (int Res = cmpNumbers(LMask[i], RMask[i])) + return Res; + } + } if (const PHINode *PNL = dyn_cast<PHINode>(L)) { const PHINode *PNR = cast<PHINode>(R); // Ensure that in addition to the incoming values being identical @@ -675,8 +698,8 @@ int FunctionComparator::cmpGEPs(const GEPOperator *GEPL, if (GEPL->accumulateConstantOffset(DL, OffsetL) && GEPR->accumulateConstantOffset(DL, OffsetR)) return cmpAPInts(OffsetL, OffsetR); - if (int Res = cmpTypes(GEPL->getSourceElementType(), - GEPR->getSourceElementType())) + if (int Res = + cmpTypes(GEPL->getSourceElementType(), GEPR->getSourceElementType())) return Res; if (int Res = cmpNumbers(GEPL->getNumOperands(), GEPR->getNumOperands())) @@ -829,8 +852,8 @@ int FunctionComparator::compareSignature() const { // Visit the arguments so that they get enumerated in the order they're // passed in. for (Function::const_arg_iterator ArgLI = FnL->arg_begin(), - ArgRI = FnR->arg_begin(), - ArgLE = FnL->arg_end(); + ArgRI = FnR->arg_begin(), + ArgLE = FnL->arg_end(); ArgLI != ArgLE; ++ArgLI, ++ArgRI) { if (cmpValues(&*ArgLI, &*ArgRI) != 0) llvm_unreachable("Arguments repeat!"); @@ -897,9 +920,7 @@ public: // Initialize to random constant, so the state isn't zero. HashAccumulator64() { Hash = 0x6acaa36bef8325c5ULL; } - void add(uint64_t V) { - Hash = hashing::detail::hash_16_bytes(Hash, V); - } + void add(uint64_t V) { Hash = hashing::detail::hash_16_bytes(Hash, V); } // No finishing is required, because the entire hash value is used. uint64_t getHash() { return Hash; } |