aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/IR/Constants.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/IR/Constants.cpp')
-rw-r--r--llvm/lib/IR/Constants.cpp541
1 files changed, 369 insertions, 172 deletions
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 054375aab6c3..cbbcca20ea51 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -160,8 +160,8 @@ bool Constant::isNotOneValue() const {
return !CFP->getValueAPF().bitcastToAPInt().isOneValue();
// Check that vectors don't contain 1
- if (this->getType()->isVectorTy()) {
- unsigned NumElts = this->getType()->getVectorNumElements();
+ if (auto *VTy = dyn_cast<VectorType>(this->getType())) {
+ unsigned NumElts = VTy->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = this->getAggregateElement(i);
if (!Elt || !Elt->isNotOneValue())
@@ -210,8 +210,8 @@ bool Constant::isNotMinSignedValue() const {
return !CFP->getValueAPF().bitcastToAPInt().isMinSignedValue();
// Check that vectors don't contain INT_MIN
- if (this->getType()->isVectorTy()) {
- unsigned NumElts = this->getType()->getVectorNumElements();
+ if (auto *VTy = dyn_cast<VectorType>(this->getType())) {
+ unsigned NumElts = VTy->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = this->getAggregateElement(i);
if (!Elt || !Elt->isNotMinSignedValue())
@@ -227,9 +227,10 @@ bool Constant::isNotMinSignedValue() const {
bool Constant::isFiniteNonZeroFP() const {
if (auto *CFP = dyn_cast<ConstantFP>(this))
return CFP->getValueAPF().isFiniteNonZero();
- if (!getType()->isVectorTy())
+ auto *VTy = dyn_cast<VectorType>(getType());
+ if (!VTy)
return false;
- for (unsigned i = 0, e = getType()->getVectorNumElements(); i != e; ++i) {
+ for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
auto *CFP = dyn_cast_or_null<ConstantFP>(this->getAggregateElement(i));
if (!CFP || !CFP->getValueAPF().isFiniteNonZero())
return false;
@@ -240,9 +241,10 @@ bool Constant::isFiniteNonZeroFP() const {
bool Constant::isNormalFP() const {
if (auto *CFP = dyn_cast<ConstantFP>(this))
return CFP->getValueAPF().isNormal();
- if (!getType()->isVectorTy())
+ auto *VTy = dyn_cast<FixedVectorType>(getType());
+ if (!VTy)
return false;
- for (unsigned i = 0, e = getType()->getVectorNumElements(); i != e; ++i) {
+ for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
auto *CFP = dyn_cast_or_null<ConstantFP>(this->getAggregateElement(i));
if (!CFP || !CFP->getValueAPF().isNormal())
return false;
@@ -253,9 +255,10 @@ bool Constant::isNormalFP() const {
bool Constant::hasExactInverseFP() const {
if (auto *CFP = dyn_cast<ConstantFP>(this))
return CFP->getValueAPF().getExactInverse(nullptr);
- if (!getType()->isVectorTy())
+ auto *VTy = dyn_cast<FixedVectorType>(getType());
+ if (!VTy)
return false;
- for (unsigned i = 0, e = getType()->getVectorNumElements(); i != e; ++i) {
+ for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
auto *CFP = dyn_cast_or_null<ConstantFP>(this->getAggregateElement(i));
if (!CFP || !CFP->getValueAPF().getExactInverse(nullptr))
return false;
@@ -266,9 +269,10 @@ bool Constant::hasExactInverseFP() const {
bool Constant::isNaN() const {
if (auto *CFP = dyn_cast<ConstantFP>(this))
return CFP->isNaN();
- if (!getType()->isVectorTy())
+ auto *VTy = dyn_cast<FixedVectorType>(getType());
+ if (!VTy)
return false;
- for (unsigned i = 0, e = getType()->getVectorNumElements(); i != e; ++i) {
+ for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
auto *CFP = dyn_cast_or_null<ConstantFP>(this->getAggregateElement(i));
if (!CFP || !CFP->isNaN())
return false;
@@ -282,34 +286,40 @@ bool Constant::isElementWiseEqual(Value *Y) const {
return true;
// The input value must be a vector constant with the same type.
- Type *Ty = getType();
- if (!isa<Constant>(Y) || !Ty->isVectorTy() || Ty != Y->getType())
+ auto *VTy = dyn_cast<VectorType>(getType());
+ if (!isa<Constant>(Y) || !VTy || VTy != Y->getType())
+ return false;
+
+ // TODO: Compare pointer constants?
+ if (!(VTy->getElementType()->isIntegerTy() ||
+ VTy->getElementType()->isFloatingPointTy()))
return false;
// They may still be identical element-wise (if they have `undef`s).
- // FIXME: This crashes on FP vector constants.
- return match(ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_EQ,
- const_cast<Constant *>(this),
- cast<Constant>(Y)),
- m_One());
+ // Bitcast to integer to allow exact bitwise comparison for all types.
+ Type *IntTy = VectorType::getInteger(VTy);
+ Constant *C0 = ConstantExpr::getBitCast(const_cast<Constant *>(this), IntTy);
+ Constant *C1 = ConstantExpr::getBitCast(cast<Constant>(Y), IntTy);
+ Constant *CmpEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, C0, C1);
+ return isa<UndefValue>(CmpEq) || match(CmpEq, m_One());
}
bool Constant::containsUndefElement() const {
- if (!getType()->isVectorTy())
- return false;
- for (unsigned i = 0, e = getType()->getVectorNumElements(); i != e; ++i)
- if (isa<UndefValue>(getAggregateElement(i)))
- return true;
+ if (auto *VTy = dyn_cast<VectorType>(getType())) {
+ for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i)
+ if (isa<UndefValue>(getAggregateElement(i)))
+ return true;
+ }
return false;
}
bool Constant::containsConstantExpression() const {
- if (!getType()->isVectorTy())
- return false;
- for (unsigned i = 0, e = getType()->getVectorNumElements(); i != e; ++i)
- if (isa<ConstantExpr>(getAggregateElement(i)))
- return true;
+ if (auto *VTy = dyn_cast<VectorType>(getType())) {
+ for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i)
+ if (isa<ConstantExpr>(getAggregateElement(i)))
+ return true;
+ }
return false;
}
@@ -322,6 +332,9 @@ Constant *Constant::getNullValue(Type *Ty) {
case Type::HalfTyID:
return ConstantFP::get(Ty->getContext(),
APFloat::getZero(APFloat::IEEEhalf()));
+ case Type::BFloatTyID:
+ return ConstantFP::get(Ty->getContext(),
+ APFloat::getZero(APFloat::BFloat()));
case Type::FloatTyID:
return ConstantFP::get(Ty->getContext(),
APFloat::getZero(APFloat::IEEEsingle()));
@@ -342,7 +355,8 @@ Constant *Constant::getNullValue(Type *Ty) {
return ConstantPointerNull::get(cast<PointerType>(Ty));
case Type::StructTyID:
case Type::ArrayTyID:
- case Type::VectorTyID:
+ case Type::FixedVectorTyID:
+ case Type::ScalableVectorTyID:
return ConstantAggregateZero::get(Ty);
case Type::TokenTyID:
return ConstantTokenNone::get(Ty->getContext());
@@ -364,7 +378,7 @@ Constant *Constant::getIntegerValue(Type *Ty, const APInt &V) {
// Broadcast a scalar to a vector, if necessary.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- C = ConstantVector::getSplat(VTy->getNumElements(), C);
+ C = ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -375,13 +389,13 @@ Constant *Constant::getAllOnesValue(Type *Ty) {
APInt::getAllOnesValue(ITy->getBitWidth()));
if (Ty->isFloatingPointTy()) {
- APFloat FL = APFloat::getAllOnesValue(Ty->getPrimitiveSizeInBits(),
- !Ty->isPPC_FP128Ty());
+ APFloat FL = APFloat::getAllOnesValue(Ty->getFltSemantics(),
+ Ty->getPrimitiveSizeInBits());
return ConstantFP::get(Ty->getContext(), FL);
}
VectorType *VTy = cast<VectorType>(Ty);
- return ConstantVector::getSplat(VTy->getNumElements(),
+ return ConstantVector::getSplat(VTy->getElementCount(),
getAllOnesValue(VTy->getElementType()));
}
@@ -449,7 +463,74 @@ void Constant::destroyConstant() {
}
// Value has no outstanding references it is safe to delete it now...
- delete this;
+ deleteConstant(this);
+}
+
+void llvm::deleteConstant(Constant *C) {
+ switch (C->getValueID()) {
+ case Constant::ConstantIntVal:
+ delete static_cast<ConstantInt *>(C);
+ break;
+ case Constant::ConstantFPVal:
+ delete static_cast<ConstantFP *>(C);
+ break;
+ case Constant::ConstantAggregateZeroVal:
+ delete static_cast<ConstantAggregateZero *>(C);
+ break;
+ case Constant::ConstantArrayVal:
+ delete static_cast<ConstantArray *>(C);
+ break;
+ case Constant::ConstantStructVal:
+ delete static_cast<ConstantStruct *>(C);
+ break;
+ case Constant::ConstantVectorVal:
+ delete static_cast<ConstantVector *>(C);
+ break;
+ case Constant::ConstantPointerNullVal:
+ delete static_cast<ConstantPointerNull *>(C);
+ break;
+ case Constant::ConstantDataArrayVal:
+ delete static_cast<ConstantDataArray *>(C);
+ break;
+ case Constant::ConstantDataVectorVal:
+ delete static_cast<ConstantDataVector *>(C);
+ break;
+ case Constant::ConstantTokenNoneVal:
+ delete static_cast<ConstantTokenNone *>(C);
+ break;
+ case Constant::BlockAddressVal:
+ delete static_cast<BlockAddress *>(C);
+ break;
+ case Constant::UndefValueVal:
+ delete static_cast<UndefValue *>(C);
+ break;
+ case Constant::ConstantExprVal:
+ if (isa<UnaryConstantExpr>(C))
+ delete static_cast<UnaryConstantExpr *>(C);
+ else if (isa<BinaryConstantExpr>(C))
+ delete static_cast<BinaryConstantExpr *>(C);
+ else if (isa<SelectConstantExpr>(C))
+ delete static_cast<SelectConstantExpr *>(C);
+ else if (isa<ExtractElementConstantExpr>(C))
+ delete static_cast<ExtractElementConstantExpr *>(C);
+ else if (isa<InsertElementConstantExpr>(C))
+ delete static_cast<InsertElementConstantExpr *>(C);
+ else if (isa<ShuffleVectorConstantExpr>(C))
+ delete static_cast<ShuffleVectorConstantExpr *>(C);
+ else if (isa<ExtractValueConstantExpr>(C))
+ delete static_cast<ExtractValueConstantExpr *>(C);
+ else if (isa<InsertValueConstantExpr>(C))
+ delete static_cast<InsertValueConstantExpr *>(C);
+ else if (isa<GetElementPtrConstantExpr>(C))
+ delete static_cast<GetElementPtrConstantExpr *>(C);
+ else if (isa<CompareConstantExpr>(C))
+ delete static_cast<CompareConstantExpr *>(C);
+ else
+ llvm_unreachable("Unexpected constant expr");
+ break;
+ default:
+ llvm_unreachable("Unexpected constant");
+ }
}
static bool canTrapImpl(const Constant *C,
@@ -633,10 +714,11 @@ Constant *Constant::replaceUndefsWith(Constant *C, Constant *Replacement) {
}
// Don't know how to deal with this constant.
- if (!Ty->isVectorTy())
+ auto *VTy = dyn_cast<FixedVectorType>(Ty);
+ if (!VTy)
return C;
- unsigned NumElts = Ty->getVectorNumElements();
+ unsigned NumElts = VTy->getNumElements();
SmallVector<Constant *, 32> NewC(NumElts);
for (unsigned i = 0; i != NumElts; ++i) {
Constant *EltC = C->getAggregateElement(i);
@@ -675,7 +757,7 @@ Constant *ConstantInt::getTrue(Type *Ty) {
assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1.");
ConstantInt *TrueC = ConstantInt::getTrue(Ty->getContext());
if (auto *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), TrueC);
+ return ConstantVector::getSplat(VTy->getElementCount(), TrueC);
return TrueC;
}
@@ -683,7 +765,7 @@ Constant *ConstantInt::getFalse(Type *Ty) {
assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1.");
ConstantInt *FalseC = ConstantInt::getFalse(Ty->getContext());
if (auto *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), FalseC);
+ return ConstantVector::getSplat(VTy->getElementCount(), FalseC);
return FalseC;
}
@@ -706,7 +788,7 @@ Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) {
// For vectors, broadcast the value.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -730,7 +812,7 @@ Constant *ConstantInt::get(Type *Ty, const APInt& V) {
// For vectors, broadcast the value.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -751,6 +833,8 @@ void ConstantInt::destroyConstantImpl() {
static const fltSemantics *TypeToFloatSemantics(Type *Ty) {
if (Ty->isHalfTy())
return &APFloat::IEEEhalf();
+ if (Ty->isBFloatTy())
+ return &APFloat::BFloat();
if (Ty->isFloatTy())
return &APFloat::IEEEsingle();
if (Ty->isDoubleTy())
@@ -775,7 +859,7 @@ Constant *ConstantFP::get(Type *Ty, double V) {
// For vectors, broadcast the value.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -787,7 +871,7 @@ Constant *ConstantFP::get(Type *Ty, const APFloat &V) {
// For vectors, broadcast the value.
if (auto *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -800,7 +884,7 @@ Constant *ConstantFP::get(Type *Ty, StringRef Str) {
// For vectors, broadcast the value.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -811,7 +895,7 @@ Constant *ConstantFP::getNaN(Type *Ty, bool Negative, uint64_t Payload) {
Constant *C = get(Ty->getContext(), NaN);
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -820,10 +904,10 @@ Constant *ConstantFP::getQNaN(Type *Ty, bool Negative, APInt *Payload) {
const fltSemantics &Semantics = *TypeToFloatSemantics(Ty->getScalarType());
APFloat NaN = APFloat::getQNaN(Semantics, Negative, Payload);
Constant *C = get(Ty->getContext(), NaN);
-
+
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
-
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
+
return C;
}
@@ -831,10 +915,10 @@ Constant *ConstantFP::getSNaN(Type *Ty, bool Negative, APInt *Payload) {
const fltSemantics &Semantics = *TypeToFloatSemantics(Ty->getScalarType());
APFloat NaN = APFloat::getSNaN(Semantics, Negative, Payload);
Constant *C = get(Ty->getContext(), NaN);
-
+
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
-
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
+
return C;
}
@@ -844,7 +928,7 @@ Constant *ConstantFP::getNegativeZero(Type *Ty) {
Constant *C = get(Ty->getContext(), NegZero);
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -868,6 +952,8 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) {
Type *Ty;
if (&V.getSemantics() == &APFloat::IEEEhalf())
Ty = Type::getHalfTy(Context);
+ else if (&V.getSemantics() == &APFloat::BFloat())
+ Ty = Type::getBFloatTy(Context);
else if (&V.getSemantics() == &APFloat::IEEEsingle())
Ty = Type::getFloatTy(Context);
else if (&V.getSemantics() == &APFloat::IEEEdouble())
@@ -892,7 +978,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative));
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -917,7 +1003,9 @@ void ConstantFP::destroyConstantImpl() {
//===----------------------------------------------------------------------===//
Constant *ConstantAggregateZero::getSequentialElement() const {
- return Constant::getNullValue(getType()->getSequentialElementType());
+ if (auto *AT = dyn_cast<ArrayType>(getType()))
+ return Constant::getNullValue(AT->getElementType());
+ return Constant::getNullValue(cast<VectorType>(getType())->getElementType());
}
Constant *ConstantAggregateZero::getStructElement(unsigned Elt) const {
@@ -925,13 +1013,13 @@ Constant *ConstantAggregateZero::getStructElement(unsigned Elt) const {
}
Constant *ConstantAggregateZero::getElementValue(Constant *C) const {
- if (isa<SequentialType>(getType()))
+ if (isa<ArrayType>(getType()) || isa<VectorType>(getType()))
return getSequentialElement();
return getStructElement(cast<ConstantInt>(C)->getZExtValue());
}
Constant *ConstantAggregateZero::getElementValue(unsigned Idx) const {
- if (isa<SequentialType>(getType()))
+ if (isa<ArrayType>(getType()) || isa<VectorType>(getType()))
return getSequentialElement();
return getStructElement(Idx);
}
@@ -950,7 +1038,9 @@ unsigned ConstantAggregateZero::getNumElements() const {
//===----------------------------------------------------------------------===//
UndefValue *UndefValue::getSequentialElement() const {
- return UndefValue::get(getType()->getSequentialElementType());
+ if (ArrayType *ATy = dyn_cast<ArrayType>(getType()))
+ return UndefValue::get(ATy->getElementType());
+ return UndefValue::get(cast<VectorType>(getType())->getElementType());
}
UndefValue *UndefValue::getStructElement(unsigned Elt) const {
@@ -958,21 +1048,23 @@ UndefValue *UndefValue::getStructElement(unsigned Elt) const {
}
UndefValue *UndefValue::getElementValue(Constant *C) const {
- if (isa<SequentialType>(getType()))
+ if (isa<ArrayType>(getType()) || isa<VectorType>(getType()))
return getSequentialElement();
return getStructElement(cast<ConstantInt>(C)->getZExtValue());
}
UndefValue *UndefValue::getElementValue(unsigned Idx) const {
- if (isa<SequentialType>(getType()))
+ if (isa<ArrayType>(getType()) || isa<VectorType>(getType()))
return getSequentialElement();
return getStructElement(Idx);
}
unsigned UndefValue::getNumElements() const {
Type *Ty = getType();
- if (auto *ST = dyn_cast<SequentialType>(Ty))
- return ST->getNumElements();
+ if (auto *AT = dyn_cast<ArrayType>(Ty))
+ return AT->getNumElements();
+ if (auto *VT = dyn_cast<VectorType>(Ty))
+ return VT->getNumElements();
return Ty->getStructNumElements();
}
@@ -1011,7 +1103,7 @@ static Constant *getFPSequenceIfElementsMatch(ArrayRef<Constant *> V) {
Elts.push_back(CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
else
return nullptr;
- return SequentialTy::getFP(V[0]->getContext(), Elts);
+ return SequentialTy::getFP(V[0]->getType(), Elts);
}
template <typename SequenceTy>
@@ -1030,7 +1122,7 @@ static Constant *getSequenceIfElementsMatch(Constant *C,
else if (CI->getType()->isIntegerTy(64))
return getIntSequenceIfElementsMatch<SequenceTy, uint64_t>(V);
} else if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
- if (CFP->getType()->isHalfTy())
+ if (CFP->getType()->isHalfTy() || CFP->getType()->isBFloatTy())
return getFPSequenceIfElementsMatch<SequenceTy, uint16_t>(V);
else if (CFP->getType()->isFloatTy())
return getFPSequenceIfElementsMatch<SequenceTy, uint32_t>(V);
@@ -1041,19 +1133,20 @@ static Constant *getSequenceIfElementsMatch(Constant *C,
return nullptr;
}
-ConstantAggregate::ConstantAggregate(CompositeType *T, ValueTy VT,
+ConstantAggregate::ConstantAggregate(Type *T, ValueTy VT,
ArrayRef<Constant *> V)
: Constant(T, VT, OperandTraits<ConstantAggregate>::op_end(this) - V.size(),
V.size()) {
llvm::copy(V, op_begin());
// Check that types match, unless this is an opaque struct.
- if (auto *ST = dyn_cast<StructType>(T))
+ if (auto *ST = dyn_cast<StructType>(T)) {
if (ST->isOpaque())
return;
- for (unsigned I = 0, E = V.size(); I != E; ++I)
- assert(V[I]->getType() == T->getTypeAtIndex(I) &&
- "Initializer for composite element doesn't match!");
+ for (unsigned I = 0, E = V.size(); I != E; ++I)
+ assert(V[I]->getType() == ST->getTypeAtIndex(I) &&
+ "Initializer for struct element doesn't match!");
+ }
}
ConstantArray::ConstantArray(ArrayType *T, ArrayRef<Constant *> V)
@@ -1161,13 +1254,13 @@ ConstantVector::ConstantVector(VectorType *T, ArrayRef<Constant *> V)
Constant *ConstantVector::get(ArrayRef<Constant*> V) {
if (Constant *C = getImpl(V))
return C;
- VectorType *Ty = VectorType::get(V.front()->getType(), V.size());
+ auto *Ty = FixedVectorType::get(V.front()->getType(), V.size());
return Ty->getContext().pImpl->VectorConstants.getOrCreate(Ty, V);
}
Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
assert(!V.empty() && "Vectors can't be empty");
- VectorType *T = VectorType::get(V.front()->getType(), V.size());
+ auto *T = FixedVectorType::get(V.front()->getType(), V.size());
// If this is an all-undef or all-zero vector, return a
// ConstantAggregateZero or UndefValue.
@@ -1198,15 +1291,34 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
return nullptr;
}
-Constant *ConstantVector::getSplat(unsigned NumElts, Constant *V) {
- // If this splat is compatible with ConstantDataVector, use it instead of
- // ConstantVector.
- if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
- ConstantDataSequential::isElementTypeCompatible(V->getType()))
- return ConstantDataVector::getSplat(NumElts, V);
+Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
+ if (!EC.Scalable) {
+ // If this splat is compatible with ConstantDataVector, use it instead of
+ // ConstantVector.
+ if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
+ ConstantDataSequential::isElementTypeCompatible(V->getType()))
+ return ConstantDataVector::getSplat(EC.Min, V);
+
+ SmallVector<Constant *, 32> Elts(EC.Min, V);
+ return get(Elts);
+ }
+
+ Type *VTy = VectorType::get(V->getType(), EC);
+
+ if (V->isNullValue())
+ return ConstantAggregateZero::get(VTy);
+ else if (isa<UndefValue>(V))
+ return UndefValue::get(VTy);
- SmallVector<Constant*, 32> Elts(NumElts, V);
- return get(Elts);
+ Type *I32Ty = Type::getInt32Ty(VTy->getContext());
+
+ // Move scalar into vector.
+ Constant *UndefV = UndefValue::get(VTy);
+ V = ConstantExpr::getInsertElement(UndefV, V, ConstantInt::get(I32Ty, 0));
+ // Build shuffle mask to perform the splat.
+ SmallVector<int, 8> Zeros(EC.Min, 0);
+ // Splat.
+ return ConstantExpr::getShuffleVector(V, UndefV, Zeros);
}
ConstantTokenNone *ConstantTokenNone::get(LLVMContext &Context) {
@@ -1271,6 +1383,14 @@ unsigned ConstantExpr::getPredicate() const {
return cast<CompareConstantExpr>(this)->predicate;
}
+ArrayRef<int> ConstantExpr::getShuffleMask() const {
+ return cast<ShuffleVectorConstantExpr>(this)->ShuffleMask;
+}
+
+Constant *ConstantExpr::getShuffleMaskForBitcode() const {
+ return cast<ShuffleVectorConstantExpr>(this)->ShuffleMaskForBitcode;
+}
+
Constant *
ConstantExpr::getWithOperandReplaced(unsigned OpNo, Constant *Op) const {
assert(Op->getType() == getOperand(OpNo)->getType() &&
@@ -1322,7 +1442,7 @@ Constant *ConstantExpr::getWithOperands(ArrayRef<Constant *> Ops, Type *Ty,
case Instruction::ExtractValue:
return ConstantExpr::getExtractValue(Ops[0], getIndices(), OnlyIfReducedTy);
case Instruction::ShuffleVector:
- return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2],
+ return ConstantExpr::getShuffleVector(Ops[0], Ops[1], getShuffleMask(),
OnlyIfReducedTy);
case Instruction::GetElementPtr: {
auto *GEPO = cast<GEPOperator>(this);
@@ -1375,6 +1495,12 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
Val2.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &losesInfo);
return !losesInfo;
}
+ case Type::BFloatTyID: {
+ if (&Val2.getSemantics() == &APFloat::BFloat())
+ return true;
+ Val2.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &losesInfo);
+ return !losesInfo;
+ }
case Type::FloatTyID: {
if (&Val2.getSemantics() == &APFloat::IEEEsingle())
return true;
@@ -1383,6 +1509,7 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
}
case Type::DoubleTyID: {
if (&Val2.getSemantics() == &APFloat::IEEEhalf() ||
+ &Val2.getSemantics() == &APFloat::BFloat() ||
&Val2.getSemantics() == &APFloat::IEEEsingle() ||
&Val2.getSemantics() == &APFloat::IEEEdouble())
return true;
@@ -1391,16 +1518,19 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
}
case Type::X86_FP80TyID:
return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+ &Val2.getSemantics() == &APFloat::BFloat() ||
&Val2.getSemantics() == &APFloat::IEEEsingle() ||
&Val2.getSemantics() == &APFloat::IEEEdouble() ||
&Val2.getSemantics() == &APFloat::x87DoubleExtended();
case Type::FP128TyID:
return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+ &Val2.getSemantics() == &APFloat::BFloat() ||
&Val2.getSemantics() == &APFloat::IEEEsingle() ||
&Val2.getSemantics() == &APFloat::IEEEdouble() ||
&Val2.getSemantics() == &APFloat::IEEEquad();
case Type::PPC_FP128TyID:
return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+ &Val2.getSemantics() == &APFloat::BFloat() ||
&Val2.getSemantics() == &APFloat::IEEEsingle() ||
&Val2.getSemantics() == &APFloat::IEEEdouble() ||
&Val2.getSemantics() == &APFloat::PPCDoubleDouble();
@@ -1450,11 +1580,32 @@ void ConstantVector::destroyConstantImpl() {
Constant *Constant::getSplatValue(bool AllowUndefs) const {
assert(this->getType()->isVectorTy() && "Only valid for vectors!");
if (isa<ConstantAggregateZero>(this))
- return getNullValue(this->getType()->getVectorElementType());
+ return getNullValue(cast<VectorType>(getType())->getElementType());
if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
return CV->getSplatValue();
if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
return CV->getSplatValue(AllowUndefs);
+
+ // Check if this is a constant expression splat of the form returned by
+ // ConstantVector::getSplat()
+ const auto *Shuf = dyn_cast<ConstantExpr>(this);
+ if (Shuf && Shuf->getOpcode() == Instruction::ShuffleVector &&
+ isa<UndefValue>(Shuf->getOperand(1))) {
+
+ const auto *IElt = dyn_cast<ConstantExpr>(Shuf->getOperand(0));
+ if (IElt && IElt->getOpcode() == Instruction::InsertElement &&
+ isa<UndefValue>(IElt->getOperand(0))) {
+
+ ArrayRef<int> Mask = Shuf->getShuffleMask();
+ Constant *SplatVal = IElt->getOperand(1);
+ ConstantInt *Index = dyn_cast<ConstantInt>(IElt->getOperand(2));
+
+ if (Index && Index->getValue() == 0 &&
+ std::all_of(Mask.begin(), Mask.end(), [](int I) { return I == 0; }))
+ return SplatVal;
+ }
+ }
+
return nullptr;
}
@@ -1735,8 +1886,8 @@ Constant *ConstantExpr::getFPCast(Constant *C, Type *Ty) {
Constant *ConstantExpr::getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced) {
#ifndef NDEBUG
- bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
- bool toVec = Ty->getTypeID() == Type::VectorTyID;
+ bool fromVec = isa<VectorType>(C->getType());
+ bool toVec = isa<VectorType>(Ty);
#endif
assert((fromVec == toVec) && "Cannot convert from scalar to/from vector");
assert(C->getType()->isIntOrIntVectorTy() && "Trunc operand must be integer");
@@ -1749,8 +1900,8 @@ Constant *ConstantExpr::getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced) {
Constant *ConstantExpr::getSExt(Constant *C, Type *Ty, bool OnlyIfReduced) {
#ifndef NDEBUG
- bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
- bool toVec = Ty->getTypeID() == Type::VectorTyID;
+ bool fromVec = isa<VectorType>(C->getType());
+ bool toVec = isa<VectorType>(Ty);
#endif
assert((fromVec == toVec) && "Cannot convert from scalar to/from vector");
assert(C->getType()->isIntOrIntVectorTy() && "SExt operand must be integral");
@@ -1763,8 +1914,8 @@ Constant *ConstantExpr::getSExt(Constant *C, Type *Ty, bool OnlyIfReduced) {
Constant *ConstantExpr::getZExt(Constant *C, Type *Ty, bool OnlyIfReduced) {
#ifndef NDEBUG
- bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
- bool toVec = Ty->getTypeID() == Type::VectorTyID;
+ bool fromVec = isa<VectorType>(C->getType());
+ bool toVec = isa<VectorType>(Ty);
#endif
assert((fromVec == toVec) && "Cannot convert from scalar to/from vector");
assert(C->getType()->isIntOrIntVectorTy() && "ZEXt operand must be integral");
@@ -1777,8 +1928,8 @@ Constant *ConstantExpr::getZExt(Constant *C, Type *Ty, bool OnlyIfReduced) {
Constant *ConstantExpr::getFPTrunc(Constant *C, Type *Ty, bool OnlyIfReduced) {
#ifndef NDEBUG
- bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
- bool toVec = Ty->getTypeID() == Type::VectorTyID;
+ bool fromVec = isa<VectorType>(C->getType());
+ bool toVec = isa<VectorType>(Ty);
#endif
assert((fromVec == toVec) && "Cannot convert from scalar to/from vector");
assert(C->getType()->isFPOrFPVectorTy() && Ty->isFPOrFPVectorTy() &&
@@ -1789,8 +1940,8 @@ Constant *ConstantExpr::getFPTrunc(Constant *C, Type *Ty, bool OnlyIfReduced) {
Constant *ConstantExpr::getFPExtend(Constant *C, Type *Ty, bool OnlyIfReduced) {
#ifndef NDEBUG
- bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
- bool toVec = Ty->getTypeID() == Type::VectorTyID;
+ bool fromVec = isa<VectorType>(C->getType());
+ bool toVec = isa<VectorType>(Ty);
#endif
assert((fromVec == toVec) && "Cannot convert from scalar to/from vector");
assert(C->getType()->isFPOrFPVectorTy() && Ty->isFPOrFPVectorTy() &&
@@ -1801,8 +1952,8 @@ Constant *ConstantExpr::getFPExtend(Constant *C, Type *Ty, bool OnlyIfReduced) {
Constant *ConstantExpr::getUIToFP(Constant *C, Type *Ty, bool OnlyIfReduced) {
#ifndef NDEBUG
- bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
- bool toVec = Ty->getTypeID() == Type::VectorTyID;
+ bool fromVec = isa<VectorType>(C->getType());
+ bool toVec = isa<VectorType>(Ty);
#endif
assert((fromVec == toVec) && "Cannot convert from scalar to/from vector");
assert(C->getType()->isIntOrIntVectorTy() && Ty->isFPOrFPVectorTy() &&
@@ -1812,8 +1963,8 @@ Constant *ConstantExpr::getUIToFP(Constant *C, Type *Ty, bool OnlyIfReduced) {
Constant *ConstantExpr::getSIToFP(Constant *C, Type *Ty, bool OnlyIfReduced) {
#ifndef NDEBUG
- bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
- bool toVec = Ty->getTypeID() == Type::VectorTyID;
+ bool fromVec = isa<VectorType>(C->getType());
+ bool toVec = isa<VectorType>(Ty);
#endif
assert((fromVec == toVec) && "Cannot convert from scalar to/from vector");
assert(C->getType()->isIntOrIntVectorTy() && Ty->isFPOrFPVectorTy() &&
@@ -1823,8 +1974,8 @@ Constant *ConstantExpr::getSIToFP(Constant *C, Type *Ty, bool OnlyIfReduced) {
Constant *ConstantExpr::getFPToUI(Constant *C, Type *Ty, bool OnlyIfReduced) {
#ifndef NDEBUG
- bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
- bool toVec = Ty->getTypeID() == Type::VectorTyID;
+ bool fromVec = isa<VectorType>(C->getType());
+ bool toVec = isa<VectorType>(Ty);
#endif
assert((fromVec == toVec) && "Cannot convert from scalar to/from vector");
assert(C->getType()->isFPOrFPVectorTy() && Ty->isIntOrIntVectorTy() &&
@@ -1834,8 +1985,8 @@ Constant *ConstantExpr::getFPToUI(Constant *C, Type *Ty, bool OnlyIfReduced) {
Constant *ConstantExpr::getFPToSI(Constant *C, Type *Ty, bool OnlyIfReduced) {
#ifndef NDEBUG
- bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
- bool toVec = Ty->getTypeID() == Type::VectorTyID;
+ bool fromVec = isa<VectorType>(C->getType());
+ bool toVec = isa<VectorType>(Ty);
#endif
assert((fromVec == toVec) && "Cannot convert from scalar to/from vector");
assert(C->getType()->isFPOrFPVectorTy() && Ty->isIntOrIntVectorTy() &&
@@ -1851,7 +2002,8 @@ Constant *ConstantExpr::getPtrToInt(Constant *C, Type *DstTy,
"PtrToInt destination must be integer or integer vector");
assert(isa<VectorType>(C->getType()) == isa<VectorType>(DstTy));
if (isa<VectorType>(C->getType()))
- assert(C->getType()->getVectorNumElements()==DstTy->getVectorNumElements()&&
+ assert(cast<VectorType>(C->getType())->getNumElements() ==
+ cast<VectorType>(DstTy)->getNumElements() &&
"Invalid cast between a different number of vector elements");
return getFoldedCast(Instruction::PtrToInt, C, DstTy, OnlyIfReduced);
}
@@ -1864,7 +2016,8 @@ Constant *ConstantExpr::getIntToPtr(Constant *C, Type *DstTy,
"IntToPtr destination must be a pointer or pointer vector");
assert(isa<VectorType>(C->getType()) == isa<VectorType>(DstTy));
if (isa<VectorType>(C->getType()))
- assert(C->getType()->getVectorNumElements()==DstTy->getVectorNumElements()&&
+ assert(cast<VectorType>(C->getType())->getNumElements() ==
+ cast<VectorType>(DstTy)->getNumElements() &&
"Invalid cast between a different number of vector elements");
return getFoldedCast(Instruction::IntToPtr, C, DstTy, OnlyIfReduced);
}
@@ -1895,14 +2048,14 @@ Constant *ConstantExpr::getAddrSpaceCast(Constant *C, Type *DstTy,
Type *MidTy = PointerType::get(DstElemTy, SrcScalarTy->getAddressSpace());
if (VectorType *VT = dyn_cast<VectorType>(DstTy)) {
// Handle vectors of pointers.
- MidTy = VectorType::get(MidTy, VT->getNumElements());
+ MidTy = FixedVectorType::get(MidTy, VT->getNumElements());
}
C = getBitCast(C, MidTy);
}
return getFoldedCast(Instruction::AddrSpaceCast, C, DstTy, OnlyIfReduced);
}
-Constant *ConstantExpr::get(unsigned Opcode, Constant *C, unsigned Flags,
+Constant *ConstantExpr::get(unsigned Opcode, Constant *C, unsigned Flags,
Type *OnlyIfReducedTy) {
// Check the operands for consistency first.
assert(Instruction::isUnaryOp(Opcode) &&
@@ -2092,15 +2245,16 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C,
unsigned AS = C->getType()->getPointerAddressSpace();
Type *ReqTy = DestTy->getPointerTo(AS);
- unsigned NumVecElts = 0;
- if (C->getType()->isVectorTy())
- NumVecElts = C->getType()->getVectorNumElements();
- else for (auto Idx : Idxs)
- if (Idx->getType()->isVectorTy())
- NumVecElts = Idx->getType()->getVectorNumElements();
+ ElementCount EltCount = {0, false};
+ if (VectorType *VecTy = dyn_cast<VectorType>(C->getType()))
+ EltCount = VecTy->getElementCount();
+ else
+ for (auto Idx : Idxs)
+ if (VectorType *VecTy = dyn_cast<VectorType>(Idx->getType()))
+ EltCount = VecTy->getElementCount();
- if (NumVecElts)
- ReqTy = VectorType::get(ReqTy, NumVecElts);
+ if (EltCount.Min != 0)
+ ReqTy = VectorType::get(ReqTy, EltCount);
if (OnlyIfReducedTy == ReqTy)
return nullptr;
@@ -2109,14 +2263,20 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C,
std::vector<Constant*> ArgVec;
ArgVec.reserve(1 + Idxs.size());
ArgVec.push_back(C);
- for (unsigned i = 0, e = Idxs.size(); i != e; ++i) {
- assert((!Idxs[i]->getType()->isVectorTy() ||
- Idxs[i]->getType()->getVectorNumElements() == NumVecElts) &&
- "getelementptr index type missmatch");
-
- Constant *Idx = cast<Constant>(Idxs[i]);
- if (NumVecElts && !Idxs[i]->getType()->isVectorTy())
- Idx = ConstantVector::getSplat(NumVecElts, Idx);
+ auto GTI = gep_type_begin(Ty, Idxs), GTE = gep_type_end(Ty, Idxs);
+ for (; GTI != GTE; ++GTI) {
+ auto *Idx = cast<Constant>(GTI.getOperand());
+ assert(
+ (!isa<VectorType>(Idx->getType()) ||
+ cast<VectorType>(Idx->getType())->getElementCount() == EltCount) &&
+ "getelementptr index type missmatch");
+
+ if (GTI.isStruct() && Idx->getType()->isVectorTy()) {
+ Idx = Idx->getSplatValue();
+ } else if (GTI.isSequential() && EltCount.Min != 0 &&
+ !Idx->getType()->isVectorTy()) {
+ Idx = ConstantVector::getSplat(EltCount, Idx);
+ }
ArgVec.push_back(Idx);
}
@@ -2124,7 +2284,7 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C,
if (InRangeIndex && *InRangeIndex < 63)
SubClassOptionalData |= (*InRangeIndex + 1) << 1;
const ConstantExprKeyType Key(Instruction::GetElementPtr, ArgVec, 0,
- SubClassOptionalData, None, Ty);
+ SubClassOptionalData, None, None, Ty);
LLVMContextImpl *pImpl = C->getContext().pImpl;
return pImpl->ExprConstants.getOrCreate(ReqTy, Key);
@@ -2149,7 +2309,7 @@ Constant *ConstantExpr::getICmp(unsigned short pred, Constant *LHS,
Type *ResultTy = Type::getInt1Ty(LHS->getContext());
if (VectorType *VT = dyn_cast<VectorType>(LHS->getType()))
- ResultTy = VectorType::get(ResultTy, VT->getNumElements());
+ ResultTy = VectorType::get(ResultTy, VT->getElementCount());
LLVMContextImpl *pImpl = LHS->getType()->getContext().pImpl;
return pImpl->ExprConstants.getOrCreate(ResultTy, Key);
@@ -2174,7 +2334,7 @@ Constant *ConstantExpr::getFCmp(unsigned short pred, Constant *LHS,
Type *ResultTy = Type::getInt1Ty(LHS->getContext());
if (VectorType *VT = dyn_cast<VectorType>(LHS->getType()))
- ResultTy = VectorType::get(ResultTy, VT->getNumElements());
+ ResultTy = VectorType::get(ResultTy, VT->getElementCount());
LLVMContextImpl *pImpl = LHS->getType()->getContext().pImpl;
return pImpl->ExprConstants.getOrCreate(ResultTy, Key);
@@ -2190,7 +2350,7 @@ Constant *ConstantExpr::getExtractElement(Constant *Val, Constant *Idx,
if (Constant *FC = ConstantFoldExtractElementInstruction(Val, Idx))
return FC; // Fold a few common cases.
- Type *ReqTy = Val->getType()->getVectorElementType();
+ Type *ReqTy = cast<VectorType>(Val->getType())->getElementType();
if (OnlyIfReducedTy == ReqTy)
return nullptr;
@@ -2206,7 +2366,7 @@ Constant *ConstantExpr::getInsertElement(Constant *Val, Constant *Elt,
Constant *Idx, Type *OnlyIfReducedTy) {
assert(Val->getType()->isVectorTy() &&
"Tried to create insertelement operation on non-vector type!");
- assert(Elt->getType() == Val->getType()->getVectorElementType() &&
+ assert(Elt->getType() == cast<VectorType>(Val->getType())->getElementType() &&
"Insertelement types must match!");
assert(Idx->getType()->isIntegerTy() &&
"Insertelement index must be i32 type!");
@@ -2226,23 +2386,26 @@ Constant *ConstantExpr::getInsertElement(Constant *Val, Constant *Elt,
}
Constant *ConstantExpr::getShuffleVector(Constant *V1, Constant *V2,
- Constant *Mask, Type *OnlyIfReducedTy) {
+ ArrayRef<int> Mask,
+ Type *OnlyIfReducedTy) {
assert(ShuffleVectorInst::isValidOperands(V1, V2, Mask) &&
"Invalid shuffle vector constant expr operands!");
if (Constant *FC = ConstantFoldShuffleVectorInstruction(V1, V2, Mask))
return FC; // Fold a few common cases.
- ElementCount NElts = Mask->getType()->getVectorElementCount();
- Type *EltTy = V1->getType()->getVectorElementType();
- Type *ShufTy = VectorType::get(EltTy, NElts);
+ unsigned NElts = Mask.size();
+ auto V1VTy = cast<VectorType>(V1->getType());
+ Type *EltTy = V1VTy->getElementType();
+ bool TypeIsScalable = isa<ScalableVectorType>(V1VTy);
+ Type *ShufTy = VectorType::get(EltTy, NElts, TypeIsScalable);
if (OnlyIfReducedTy == ShufTy)
return nullptr;
// Look up the constant in the table first to ensure uniqueness
- Constant *ArgVec[] = { V1, V2, Mask };
- const ConstantExprKeyType Key(Instruction::ShuffleVector, ArgVec);
+ Constant *ArgVec[] = {V1, V2};
+ ConstantExprKeyType Key(Instruction::ShuffleVector, ArgVec, 0, 0, None, Mask);
LLVMContextImpl *pImpl = ShufTy->getContext().pImpl;
return pImpl->ExprConstants.getOrCreate(ShufTy, Key);
@@ -2499,7 +2662,9 @@ Type *GetElementPtrConstantExpr::getResultElementType() const {
// ConstantData* implementations
Type *ConstantDataSequential::getElementType() const {
- return getType()->getElementType();
+ if (ArrayType *ATy = dyn_cast<ArrayType>(getType()))
+ return ATy->getElementType();
+ return cast<VectorType>(getType())->getElementType();
}
StringRef ConstantDataSequential::getRawDataValues() const {
@@ -2507,7 +2672,8 @@ StringRef ConstantDataSequential::getRawDataValues() const {
}
bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) {
- if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) return true;
+ if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() || Ty->isDoubleTy())
+ return true;
if (auto *IT = dyn_cast<IntegerType>(Ty)) {
switch (IT->getBitWidth()) {
case 8:
@@ -2524,7 +2690,7 @@ bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) {
unsigned ConstantDataSequential::getNumElements() const {
if (ArrayType *AT = dyn_cast<ArrayType>(getType()))
return AT->getNumElements();
- return getType()->getVectorNumElements();
+ return cast<VectorType>(getType())->getNumElements();
}
@@ -2552,7 +2718,12 @@ static bool isAllZeros(StringRef Arr) {
/// the correct element type. We take the bytes in as a StringRef because
/// we *want* an underlying "char*" to avoid TBAA type punning violations.
Constant *ConstantDataSequential::getImpl(StringRef Elements, Type *Ty) {
- assert(isElementTypeCompatible(Ty->getSequentialElementType()));
+#ifndef NDEBUG
+ if (ArrayType *ATy = dyn_cast<ArrayType>(Ty))
+ assert(isElementTypeCompatible(ATy->getElementType()));
+ else
+ assert(isElementTypeCompatible(cast<VectorType>(Ty)->getElementType()));
+#endif
// If the elements are all zero or there are no elements, return a CAZ, which
// is more dense and canonical.
if (isAllZeros(Elements))
@@ -2620,26 +2791,29 @@ void ConstantDataSequential::destroyConstantImpl() {
Next = nullptr;
}
-/// getFP() constructors - Return a constant with array type with an element
-/// count and element type of float with precision matching the number of
-/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
-/// double for 64bits) Note that this can return a ConstantAggregateZero
-/// object.
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
- ArrayRef<uint16_t> Elts) {
- Type *Ty = ArrayType::get(Type::getHalfTy(Context), Elts.size());
+/// getFP() constructors - Return a constant of array type with a float
+/// element type taken from argument `ElementType', and count taken from
+/// argument `Elts'. The amount of bits of the contained type must match the
+/// number of bits of the type contained in the passed in ArrayRef.
+/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+/// that this can return a ConstantAggregateZero object.
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint16_t> Elts) {
+ assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
+ "Element type is not a 16-bit float type");
+ Type *Ty = ArrayType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 2), Ty);
}
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
- ArrayRef<uint32_t> Elts) {
- Type *Ty = ArrayType::get(Type::getFloatTy(Context), Elts.size());
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint32_t> Elts) {
+ assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type");
+ Type *Ty = ArrayType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 4), Ty);
}
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
- ArrayRef<uint64_t> Elts) {
- Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size());
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint64_t> Elts) {
+ assert(ElementType->isDoubleTy() &&
+ "Element type is not a 64-bit float type");
+ Type *Ty = ArrayType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 8), Ty);
}
@@ -2661,56 +2835,62 @@ Constant *ConstantDataArray::getString(LLVMContext &Context,
/// count and element type matching the ArrayRef passed in. Note that this
/// can return a ConstantAggregateZero object.
Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<uint8_t> Elts){
- Type *Ty = VectorType::get(Type::getInt8Ty(Context), Elts.size());
+ auto *Ty = FixedVectorType::get(Type::getInt8Ty(Context), Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 1), Ty);
}
Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<uint16_t> Elts){
- Type *Ty = VectorType::get(Type::getInt16Ty(Context), Elts.size());
+ auto *Ty = FixedVectorType::get(Type::getInt16Ty(Context), Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 2), Ty);
}
Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<uint32_t> Elts){
- Type *Ty = VectorType::get(Type::getInt32Ty(Context), Elts.size());
+ auto *Ty = FixedVectorType::get(Type::getInt32Ty(Context), Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 4), Ty);
}
Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<uint64_t> Elts){
- Type *Ty = VectorType::get(Type::getInt64Ty(Context), Elts.size());
+ auto *Ty = FixedVectorType::get(Type::getInt64Ty(Context), Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 8), Ty);
}
Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<float> Elts) {
- Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size());
+ auto *Ty = FixedVectorType::get(Type::getFloatTy(Context), Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 4), Ty);
}
Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<double> Elts) {
- Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size());
+ auto *Ty = FixedVectorType::get(Type::getDoubleTy(Context), Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 8), Ty);
}
-/// getFP() constructors - Return a constant with vector type with an element
-/// count and element type of float with the precision matching the number of
-/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
-/// double for 64bits) Note that this can return a ConstantAggregateZero
-/// object.
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+/// getFP() constructors - Return a constant of vector type with a float
+/// element type taken from argument `ElementType', and count taken from
+/// argument `Elts'. The amount of bits of the contained type must match the
+/// number of bits of the type contained in the passed in ArrayRef.
+/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+/// that this can return a ConstantAggregateZero object.
+Constant *ConstantDataVector::getFP(Type *ElementType,
ArrayRef<uint16_t> Elts) {
- Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size());
+ assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
+ "Element type is not a 16-bit float type");
+ auto *Ty = FixedVectorType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 2), Ty);
}
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+Constant *ConstantDataVector::getFP(Type *ElementType,
ArrayRef<uint32_t> Elts) {
- Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size());
+ assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type");
+ auto *Ty = FixedVectorType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 4), Ty);
}
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+Constant *ConstantDataVector::getFP(Type *ElementType,
ArrayRef<uint64_t> Elts) {
- Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size());
+ assert(ElementType->isDoubleTy() &&
+ "Element type is not a 64-bit float type");
+ auto *Ty = FixedVectorType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 8), Ty);
}
@@ -2740,20 +2920,25 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
if (CFP->getType()->isHalfTy()) {
SmallVector<uint16_t, 16> Elts(
NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
- return getFP(V->getContext(), Elts);
+ return getFP(V->getType(), Elts);
+ }
+ if (CFP->getType()->isBFloatTy()) {
+ SmallVector<uint16_t, 16> Elts(
+ NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+ return getFP(V->getType(), Elts);
}
if (CFP->getType()->isFloatTy()) {
SmallVector<uint32_t, 16> Elts(
NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
- return getFP(V->getContext(), Elts);
+ return getFP(V->getType(), Elts);
}
if (CFP->getType()->isDoubleTy()) {
SmallVector<uint64_t, 16> Elts(
NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
- return getFP(V->getContext(), Elts);
+ return getFP(V->getType(), Elts);
}
}
- return ConstantVector::getSplat(NumElts, V);
+ return ConstantVector::getSplat({NumElts, false}, V);
}
@@ -2815,6 +3000,10 @@ APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const {
auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
return APFloat(APFloat::IEEEhalf(), APInt(16, EltVal));
}
+ case Type::BFloatTyID: {
+ auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
+ return APFloat(APFloat::BFloat(), APInt(16, EltVal));
+ }
case Type::FloatTyID: {
auto EltVal = *reinterpret_cast<const uint32_t *>(EltPtr);
return APFloat(APFloat::IEEEsingle(), APInt(32, EltVal));
@@ -2839,8 +3028,8 @@ double ConstantDataSequential::getElementAsDouble(unsigned Elt) const {
}
Constant *ConstantDataSequential::getElementAsConstant(unsigned Elt) const {
- if (getElementType()->isHalfTy() || getElementType()->isFloatTy() ||
- getElementType()->isDoubleTy())
+ if (getElementType()->isHalfTy() || getElementType()->isBFloatTy() ||
+ getElementType()->isFloatTy() || getElementType()->isDoubleTy())
return ConstantFP::get(getContext(), getElementAsAPFloat(Elt));
return ConstantInt::get(getElementType(), getElementAsInteger(Elt));
@@ -2863,7 +3052,7 @@ bool ConstantDataSequential::isCString() const {
return Str.drop_back().find(0) == StringRef::npos;
}
-bool ConstantDataVector::isSplat() const {
+bool ConstantDataVector::isSplatData() const {
const char *Base = getRawDataValues().data();
// Compare elements 1+ to the 0'th element.
@@ -2875,6 +3064,14 @@ bool ConstantDataVector::isSplat() const {
return true;
}
+bool ConstantDataVector::isSplat() const {
+ if (!IsSplatSet) {
+ IsSplatSet = true;
+ IsSplat = isSplatData();
+ }
+ return IsSplat;
+}
+
Constant *ConstantDataVector::getSplatValue() const {
// If they're all the same, return the 0th one as a representative.
return isSplat() ? getElementAsConstant(0) : nullptr;
@@ -3081,7 +3278,7 @@ Instruction *ConstantExpr::getAsInstruction() const {
case Instruction::ExtractValue:
return ExtractValueInst::Create(Ops[0], getIndices());
case Instruction::ShuffleVector:
- return new ShuffleVectorInst(Ops[0], Ops[1], Ops[2]);
+ return new ShuffleVectorInst(Ops[0], Ops[1], getShuffleMask());
case Instruction::GetElementPtr: {
const auto *GO = cast<GEPOperator>(this);