Skip to content

Commit da8918f

Browse files
[SVE][NFC] Use ScalableVectorType in CGBuiltin
Summary: * Upgrade some usages of VectorType to use ScalableVectorType Reviewers: efriedma, david-arm, fpetrogalli, kmclaughlin Reviewed By: efriedma Subscribers: tschuett, rkruppe, psnobl, cfe-commits Tags: #clang Differential Revision: https://reviews.llvm.org/D78842
1 parent 3b0450a commit da8918f

File tree

2 files changed

+35
-33
lines changed

2 files changed

+35
-33
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7542,66 +7542,68 @@ llvm::Type *CodeGenFunction::getEltType(SVETypeFlags TypeFlags) {
75427542

75437543
// Return the llvm predicate vector type corresponding to the specified element
75447544
// TypeFlags.
7545-
llvm::VectorType* CodeGenFunction::getSVEPredType(SVETypeFlags TypeFlags) {
7545+
llvm::ScalableVectorType *
7546+
CodeGenFunction::getSVEPredType(SVETypeFlags TypeFlags) {
75467547
switch (TypeFlags.getEltType()) {
75477548
default: llvm_unreachable("Unhandled SVETypeFlag!");
75487549

75497550
case SVETypeFlags::EltTyInt8:
7550-
return llvm::VectorType::get(Builder.getInt1Ty(), { 16, true });
7551+
return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 16);
75517552
case SVETypeFlags::EltTyInt16:
7552-
return llvm::VectorType::get(Builder.getInt1Ty(), { 8, true });
7553+
return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 8);
75537554
case SVETypeFlags::EltTyInt32:
7554-
return llvm::VectorType::get(Builder.getInt1Ty(), { 4, true });
7555+
return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 4);
75557556
case SVETypeFlags::EltTyInt64:
7556-
return llvm::VectorType::get(Builder.getInt1Ty(), { 2, true });
7557+
return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 2);
75577558

75587559
case SVETypeFlags::EltTyFloat16:
7559-
return llvm::VectorType::get(Builder.getInt1Ty(), { 8, true });
7560+
return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 8);
75607561
case SVETypeFlags::EltTyFloat32:
7561-
return llvm::VectorType::get(Builder.getInt1Ty(), { 4, true });
7562+
return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 4);
75627563
case SVETypeFlags::EltTyFloat64:
7563-
return llvm::VectorType::get(Builder.getInt1Ty(), { 2, true });
7564+
return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 2);
75647565
}
75657566
}
75667567

75677568
// Return the llvm vector type corresponding to the specified element TypeFlags.
7568-
llvm::VectorType *CodeGenFunction::getSVEType(const SVETypeFlags &TypeFlags) {
7569+
llvm::ScalableVectorType *
7570+
CodeGenFunction::getSVEType(const SVETypeFlags &TypeFlags) {
75697571
switch (TypeFlags.getEltType()) {
75707572
default:
75717573
llvm_unreachable("Invalid SVETypeFlag!");
75727574

75737575
case SVETypeFlags::EltTyInt8:
7574-
return llvm::VectorType::get(Builder.getInt8Ty(), {16, true});
7576+
return llvm::ScalableVectorType::get(Builder.getInt8Ty(), 16);
75757577
case SVETypeFlags::EltTyInt16:
7576-
return llvm::VectorType::get(Builder.getInt16Ty(), {8, true});
7578+
return llvm::ScalableVectorType::get(Builder.getInt16Ty(), 8);
75777579
case SVETypeFlags::EltTyInt32:
7578-
return llvm::VectorType::get(Builder.getInt32Ty(), {4, true});
7580+
return llvm::ScalableVectorType::get(Builder.getInt32Ty(), 4);
75797581
case SVETypeFlags::EltTyInt64:
7580-
return llvm::VectorType::get(Builder.getInt64Ty(), {2, true});
7582+
return llvm::ScalableVectorType::get(Builder.getInt64Ty(), 2);
75817583

75827584
case SVETypeFlags::EltTyFloat16:
7583-
return llvm::VectorType::get(Builder.getHalfTy(), {8, true});
7585+
return llvm::ScalableVectorType::get(Builder.getHalfTy(), 8);
75847586
case SVETypeFlags::EltTyFloat32:
7585-
return llvm::VectorType::get(Builder.getFloatTy(), {4, true});
7587+
return llvm::ScalableVectorType::get(Builder.getFloatTy(), 4);
75867588
case SVETypeFlags::EltTyFloat64:
7587-
return llvm::VectorType::get(Builder.getDoubleTy(), {2, true});
7589+
return llvm::ScalableVectorType::get(Builder.getDoubleTy(), 2);
75887590

75897591
case SVETypeFlags::EltTyBool8:
7590-
return llvm::VectorType::get(Builder.getInt1Ty(), {16, true});
7592+
return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 16);
75917593
case SVETypeFlags::EltTyBool16:
7592-
return llvm::VectorType::get(Builder.getInt1Ty(), {8, true});
7594+
return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 8);
75937595
case SVETypeFlags::EltTyBool32:
7594-
return llvm::VectorType::get(Builder.getInt1Ty(), {4, true});
7596+
return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 4);
75957597
case SVETypeFlags::EltTyBool64:
7596-
return llvm::VectorType::get(Builder.getInt1Ty(), {2, true});
7598+
return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 2);
75977599
}
75987600
}
75997601

76007602
constexpr unsigned SVEBitsPerBlock = 128;
76017603

7602-
static llvm::VectorType* getSVEVectorForElementType(llvm::Type *EltTy) {
7604+
static llvm::ScalableVectorType *getSVEVectorForElementType(llvm::Type *EltTy) {
76037605
unsigned NumElts = SVEBitsPerBlock / EltTy->getScalarSizeInBits();
7604-
return llvm::VectorType::get(EltTy, { NumElts, true });
7606+
return llvm::ScalableVectorType::get(EltTy, NumElts);
76057607
}
76067608

76077609
// Reinterpret the input predicate so that it can be used to correctly isolate
@@ -7640,8 +7642,8 @@ Value *CodeGenFunction::EmitSVEGatherLoad(SVETypeFlags TypeFlags,
76407642
SmallVectorImpl<Value *> &Ops,
76417643
unsigned IntID) {
76427644
auto *ResultTy = getSVEType(TypeFlags);
7643-
auto *OverloadedTy = llvm::VectorType::get(SVEBuiltinMemEltTy(TypeFlags),
7644-
ResultTy->getElementCount());
7645+
auto *OverloadedTy =
7646+
llvm::ScalableVectorType::get(SVEBuiltinMemEltTy(TypeFlags), ResultTy);
76457647

76467648
// At the ACLE level there's only one predicate type, svbool_t, which is
76477649
// mapped to <n x 16 x i1>. However, this might be incompatible with the
@@ -7692,8 +7694,8 @@ Value *CodeGenFunction::EmitSVEScatterStore(SVETypeFlags TypeFlags,
76927694
SmallVectorImpl<Value *> &Ops,
76937695
unsigned IntID) {
76947696
auto *SrcDataTy = getSVEType(TypeFlags);
7695-
auto *OverloadedTy = llvm::VectorType::get(SVEBuiltinMemEltTy(TypeFlags),
7696-
SrcDataTy->getElementCount());
7697+
auto *OverloadedTy =
7698+
llvm::ScalableVectorType::get(SVEBuiltinMemEltTy(TypeFlags), SrcDataTy);
76977699

76987700
// In ACLE the source data is passed in the last argument, whereas in LLVM IR
76997701
// it's the first argument. Move it accordingly.
@@ -7748,7 +7750,7 @@ Value *CodeGenFunction::EmitSVEPrefetchLoad(SVETypeFlags TypeFlags,
77487750
unsigned BuiltinID) {
77497751
auto *MemEltTy = SVEBuiltinMemEltTy(TypeFlags);
77507752
auto *VectorTy = getSVEVectorForElementType(MemEltTy);
7751-
auto *MemoryTy = llvm::VectorType::get(MemEltTy, VectorTy->getElementCount());
7753+
auto *MemoryTy = llvm::ScalableVectorType::get(MemEltTy, VectorTy);
77527754

77537755
Value *Predicate = EmitSVEPredicateCast(Ops[0], MemoryTy);
77547756
Value *BasePtr = Ops[1];
@@ -7778,8 +7780,8 @@ Value *CodeGenFunction::EmitSVEMaskedLoad(const CallExpr *E,
77787780

77797781
// The vector type that is returned may be different from the
77807782
// eventual type loaded from memory.
7781-
auto VectorTy = cast<llvm::VectorType>(ReturnTy);
7782-
auto MemoryTy = llvm::VectorType::get(MemEltTy, VectorTy->getElementCount());
7783+
auto VectorTy = cast<llvm::ScalableVectorType>(ReturnTy);
7784+
auto MemoryTy = llvm::ScalableVectorType::get(MemEltTy, VectorTy);
77837785

77847786
Value *Predicate = EmitSVEPredicateCast(Ops[0], MemoryTy);
77857787
Value *BasePtr = Builder.CreateBitCast(Ops[1], MemoryTy->getPointerTo());
@@ -7803,8 +7805,8 @@ Value *CodeGenFunction::EmitSVEMaskedStore(const CallExpr *E,
78037805

78047806
// The vector type that is stored may be different from the
78057807
// eventual type stored to memory.
7806-
auto VectorTy = cast<llvm::VectorType>(Ops.back()->getType());
7807-
auto MemoryTy = llvm::VectorType::get(MemEltTy, VectorTy->getElementCount());
7808+
auto VectorTy = cast<llvm::ScalableVectorType>(Ops.back()->getType());
7809+
auto MemoryTy = llvm::ScalableVectorType::get(MemEltTy, VectorTy);
78087810

78097811
Value *Predicate = EmitSVEPredicateCast(Ops[0], MemoryTy);
78107812
Value *BasePtr = Builder.CreateBitCast(Ops[1], MemoryTy->getPointerTo());

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3911,8 +3911,8 @@ class CodeGenFunction : public CodeGenTypeCache {
39113911
SmallVector<llvm::Type *, 2> getSVEOverloadTypes(SVETypeFlags TypeFlags,
39123912
ArrayRef<llvm::Value *> Ops);
39133913
llvm::Type *getEltType(SVETypeFlags TypeFlags);
3914-
llvm::VectorType *getSVEType(const SVETypeFlags &TypeFlags);
3915-
llvm::VectorType *getSVEPredType(SVETypeFlags TypeFlags);
3914+
llvm::ScalableVectorType *getSVEType(const SVETypeFlags &TypeFlags);
3915+
llvm::ScalableVectorType *getSVEPredType(SVETypeFlags TypeFlags);
39163916
llvm::Value *EmitSVEDupX(llvm::Value *Scalar);
39173917
llvm::Value *EmitSVEPredicateCast(llvm::Value *Pred, llvm::VectorType *VTy);
39183918
llvm::Value *EmitSVEGatherLoad(SVETypeFlags TypeFlags,

0 commit comments

Comments
 (0)