aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/FunctionComparator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Utils/FunctionComparator.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/FunctionComparator.cpp107
1 files changed, 64 insertions, 43 deletions
diff --git a/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
index a9b28754c8e9..101cb232d8ae 100644
--- a/llvm/lib/Transforms/Utils/FunctionComparator.cpp
+++ b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
@@ -20,7 +20,6 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
@@ -52,22 +51,28 @@ using namespace llvm;
#define DEBUG_TYPE "functioncomparator"
int FunctionComparator::cmpNumbers(uint64_t L, uint64_t R) const {
- if (L < R) return -1;
- if (L > R) return 1;
+ if (L < R)
+ return -1;
+ if (L > R)
+ return 1;
return 0;
}
int FunctionComparator::cmpOrderings(AtomicOrdering L, AtomicOrdering R) const {
- if ((int)L < (int)R) return -1;
- if ((int)L > (int)R) return 1;
+ if ((int)L < (int)R)
+ return -1;
+ if ((int)L > (int)R)
+ return 1;
return 0;
}
int FunctionComparator::cmpAPInts(const APInt &L, const APInt &R) const {
if (int Res = cmpNumbers(L.getBitWidth(), R.getBitWidth()))
return Res;
- if (L.ugt(R)) return 1;
- if (R.ugt(L)) return -1;
+ if (L.ugt(R))
+ return 1;
+ if (R.ugt(L))
+ return -1;
return 0;
}
@@ -166,21 +171,17 @@ int FunctionComparator::cmpRangeMetadata(const MDNode *L,
return 0;
}
-int FunctionComparator::cmpOperandBundlesSchema(const Instruction *L,
- const Instruction *R) const {
- ImmutableCallSite LCS(L);
- ImmutableCallSite RCS(R);
-
- assert(LCS && RCS && "Must be calls or invokes!");
- assert(LCS.isCall() == RCS.isCall() && "Can't compare otherwise!");
+int FunctionComparator::cmpOperandBundlesSchema(const CallBase &LCS,
+ const CallBase &RCS) const {
+ assert(LCS.getOpcode() == RCS.getOpcode() && "Can't compare otherwise!");
if (int Res =
cmpNumbers(LCS.getNumOperandBundles(), RCS.getNumOperandBundles()))
return Res;
- for (unsigned i = 0, e = LCS.getNumOperandBundles(); i != e; ++i) {
- auto OBL = LCS.getOperandBundleAt(i);
- auto OBR = RCS.getOperandBundleAt(i);
+ for (unsigned I = 0, E = LCS.getNumOperandBundles(); I != E; ++I) {
+ auto OBL = LCS.getOperandBundleAt(I);
+ auto OBR = RCS.getOperandBundleAt(I);
if (int Res = OBL.getTagName().compare(OBR.getTagName()))
return Res;
@@ -227,9 +228,9 @@ int FunctionComparator::cmpConstants(const Constant *L,
unsigned TyRWidth = 0;
if (auto *VecTyL = dyn_cast<VectorType>(TyL))
- TyLWidth = VecTyL->getBitWidth();
+ TyLWidth = VecTyL->getPrimitiveSizeInBits().getFixedSize();
if (auto *VecTyR = dyn_cast<VectorType>(TyR))
- TyRWidth = VecTyR->getBitWidth();
+ TyRWidth = VecTyR->getPrimitiveSizeInBits().getFixedSize();
if (TyLWidth != TyRWidth)
return cmpNumbers(TyLWidth, TyRWidth);
@@ -328,8 +329,8 @@ int FunctionComparator::cmpConstants(const Constant *L,
case Value::ConstantVectorVal: {
const ConstantVector *LV = cast<ConstantVector>(L);
const ConstantVector *RV = cast<ConstantVector>(R);
- unsigned NumElementsL = cast<VectorType>(TyL)->getNumElements();
- unsigned NumElementsR = cast<VectorType>(TyR)->getNumElements();
+ unsigned NumElementsL = cast<FixedVectorType>(TyL)->getNumElements();
+ unsigned NumElementsR = cast<FixedVectorType>(TyR)->getNumElements();
if (int Res = cmpNumbers(NumElementsL, NumElementsR))
return Res;
for (uint64_t i = 0; i < NumElementsL; ++i) {
@@ -361,12 +362,12 @@ int FunctionComparator::cmpConstants(const Constant *L,
if (LBA->getFunction() == RBA->getFunction()) {
// They are BBs in the same function. Order by which comes first in the
// BB order of the function. This order is deterministic.
- Function* F = LBA->getFunction();
+ Function *F = LBA->getFunction();
BasicBlock *LBB = LBA->getBasicBlock();
BasicBlock *RBB = RBA->getBasicBlock();
if (LBB == RBB)
return 0;
- for(BasicBlock &BB : F->getBasicBlockList()) {
+ for (BasicBlock &BB : F->getBasicBlockList()) {
if (&BB == LBB) {
assert(&BB != RBB);
return -1;
@@ -476,14 +477,25 @@ int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const {
return 0;
}
- case Type::ArrayTyID:
- case Type::VectorTyID: {
- auto *STyL = cast<SequentialType>(TyL);
- auto *STyR = cast<SequentialType>(TyR);
+ case Type::ArrayTyID: {
+ auto *STyL = cast<ArrayType>(TyL);
+ auto *STyR = cast<ArrayType>(TyR);
if (STyL->getNumElements() != STyR->getNumElements())
return cmpNumbers(STyL->getNumElements(), STyR->getNumElements());
return cmpTypes(STyL->getElementType(), STyR->getElementType());
}
+ case Type::FixedVectorTyID:
+ case Type::ScalableVectorTyID: {
+ auto *STyL = cast<VectorType>(TyL);
+ auto *STyR = cast<VectorType>(TyR);
+ if (STyL->getElementCount().Scalable != STyR->getElementCount().Scalable)
+ return cmpNumbers(STyL->getElementCount().Scalable,
+ STyR->getElementCount().Scalable);
+ if (STyL->getElementCount().Min != STyR->getElementCount().Min)
+ return cmpNumbers(STyL->getElementCount().Min,
+ STyR->getElementCount().Min);
+ return cmpTypes(STyL->getElementType(), STyR->getElementType());
+ }
}
}
@@ -551,7 +563,8 @@ int FunctionComparator::cmpOperations(const Instruction *L,
if (int Res = cmpNumbers(LI->getSyncScopeID(),
cast<LoadInst>(R)->getSyncScopeID()))
return Res;
- return cmpRangeMetadata(LI->getMetadata(LLVMContext::MD_range),
+ return cmpRangeMetadata(
+ LI->getMetadata(LLVMContext::MD_range),
cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range));
}
if (const StoreInst *SI = dyn_cast<StoreInst>(L)) {
@@ -569,13 +582,13 @@ int FunctionComparator::cmpOperations(const Instruction *L,
}
if (const CmpInst *CI = dyn_cast<CmpInst>(L))
return cmpNumbers(CI->getPredicate(), cast<CmpInst>(R)->getPredicate());
- if (auto CSL = CallSite(const_cast<Instruction *>(L))) {
- auto CSR = CallSite(const_cast<Instruction *>(R));
- if (int Res = cmpNumbers(CSL.getCallingConv(), CSR.getCallingConv()))
+ if (auto *CBL = dyn_cast<CallBase>(L)) {
+ auto *CBR = cast<CallBase>(R);
+ if (int Res = cmpNumbers(CBL->getCallingConv(), CBR->getCallingConv()))
return Res;
- if (int Res = cmpAttrs(CSL.getAttributes(), CSR.getAttributes()))
+ if (int Res = cmpAttrs(CBL->getAttributes(), CBR->getAttributes()))
return Res;
- if (int Res = cmpOperandBundlesSchema(L, R))
+ if (int Res = cmpOperandBundlesSchema(*CBL, *CBR))
return Res;
if (const CallInst *CI = dyn_cast<CallInst>(L))
if (int Res = cmpNumbers(CI->getTailCallKind(),
@@ -616,8 +629,8 @@ int FunctionComparator::cmpOperations(const Instruction *L,
if (int Res = cmpNumbers(CXI->isVolatile(),
cast<AtomicCmpXchgInst>(R)->isVolatile()))
return Res;
- if (int Res = cmpNumbers(CXI->isWeak(),
- cast<AtomicCmpXchgInst>(R)->isWeak()))
+ if (int Res =
+ cmpNumbers(CXI->isWeak(), cast<AtomicCmpXchgInst>(R)->isWeak()))
return Res;
if (int Res =
cmpOrderings(CXI->getSuccessOrdering(),
@@ -638,11 +651,21 @@ int FunctionComparator::cmpOperations(const Instruction *L,
cast<AtomicRMWInst>(R)->isVolatile()))
return Res;
if (int Res = cmpOrderings(RMWI->getOrdering(),
- cast<AtomicRMWInst>(R)->getOrdering()))
+ cast<AtomicRMWInst>(R)->getOrdering()))
return Res;
return cmpNumbers(RMWI->getSyncScopeID(),
cast<AtomicRMWInst>(R)->getSyncScopeID());
}
+ if (const ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(L)) {
+ ArrayRef<int> LMask = SVI->getShuffleMask();
+ ArrayRef<int> RMask = cast<ShuffleVectorInst>(R)->getShuffleMask();
+ if (int Res = cmpNumbers(LMask.size(), RMask.size()))
+ return Res;
+ for (size_t i = 0, e = LMask.size(); i != e; ++i) {
+ if (int Res = cmpNumbers(LMask[i], RMask[i]))
+ return Res;
+ }
+ }
if (const PHINode *PNL = dyn_cast<PHINode>(L)) {
const PHINode *PNR = cast<PHINode>(R);
// Ensure that in addition to the incoming values being identical
@@ -675,8 +698,8 @@ int FunctionComparator::cmpGEPs(const GEPOperator *GEPL,
if (GEPL->accumulateConstantOffset(DL, OffsetL) &&
GEPR->accumulateConstantOffset(DL, OffsetR))
return cmpAPInts(OffsetL, OffsetR);
- if (int Res = cmpTypes(GEPL->getSourceElementType(),
- GEPR->getSourceElementType()))
+ if (int Res =
+ cmpTypes(GEPL->getSourceElementType(), GEPR->getSourceElementType()))
return Res;
if (int Res = cmpNumbers(GEPL->getNumOperands(), GEPR->getNumOperands()))
@@ -829,8 +852,8 @@ int FunctionComparator::compareSignature() const {
// Visit the arguments so that they get enumerated in the order they're
// passed in.
for (Function::const_arg_iterator ArgLI = FnL->arg_begin(),
- ArgRI = FnR->arg_begin(),
- ArgLE = FnL->arg_end();
+ ArgRI = FnR->arg_begin(),
+ ArgLE = FnL->arg_end();
ArgLI != ArgLE; ++ArgLI, ++ArgRI) {
if (cmpValues(&*ArgLI, &*ArgRI) != 0)
llvm_unreachable("Arguments repeat!");
@@ -897,9 +920,7 @@ public:
// Initialize to random constant, so the state isn't zero.
HashAccumulator64() { Hash = 0x6acaa36bef8325c5ULL; }
- void add(uint64_t V) {
- Hash = hashing::detail::hash_16_bytes(Hash, V);
- }
+ void add(uint64_t V) { Hash = hashing::detail::hash_16_bytes(Hash, V); }
// No finishing is required, because the entire hash value is used.
uint64_t getHash() { return Hash; }