Skip to content

Commit f4c67df

Browse files
committed
[X86] More accurately model the cost of horizontal reductions.
This patch attempts to more accurately model the reduction of power of 2 vectors of types we natively support. This takes into account the narrowing of vectors that occur as we go from 512 bits to 256 bits, to 128 bits. It also takes into account the use of wider elements in the shuffles for the first 2 steps of a reduction from 128 bits. And uses a v8i16 shift for the final step of vXi8 reduction. The default implementation uses the legalized type for the arithmetic for all levels. And uses the single source permute cost of the legalized type for all levels. This penalizes things like lack of v16i8 pshufb on pre-sse3 targets and the splitting and joining that needs to be done for integer types on AVX1. We never need v16i8 shuffle for a reduction and we only need split AVX1 ops when type the type wide and needs to be split. I think we're still over costing splits and joins for AVX1, but we're closer now. I've also removed all pairwise special casing because I don't think we ever want to generate that on X86. I've also adjusted the add handling to more accurately account for any type splitting that occurs before we reach a legal type. Differential Revision: https://reviews.llvm.org/D76478
1 parent 7cfd5de commit f4c67df

File tree

8 files changed

+551
-736
lines changed

8 files changed

+551
-736
lines changed

llvm/lib/Target/X86/X86TargetTransformInfo.cpp

Lines changed: 102 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2648,39 +2648,13 @@ int X86TTIImpl::getAddressComputationCost(Type *Ty, ScalarEvolution *SE,
26482648

26492649
int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy,
26502650
bool IsPairwise) {
2651+
// Just use the default implementation for pair reductions.
2652+
if (IsPairwise)
2653+
return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwise);
2654+
26512655
// We use the Intel Architecture Code Analyzer(IACA) to measure the throughput
26522656
// and make it as the cost.
26532657

2654-
static const CostTblEntry SLMCostTblPairWise[] = {
2655-
{ ISD::FADD, MVT::v2f64, 3 },
2656-
{ ISD::ADD, MVT::v2i64, 5 },
2657-
};
2658-
2659-
static const CostTblEntry SSE2CostTblPairWise[] = {
2660-
{ ISD::FADD, MVT::v2f64, 2 },
2661-
{ ISD::FADD, MVT::v4f32, 4 },
2662-
{ ISD::ADD, MVT::v2i64, 2 }, // The data reported by the IACA tool is "1.6".
2663-
{ ISD::ADD, MVT::v2i32, 2 }, // FIXME: chosen to be less than v4i32.
2664-
{ ISD::ADD, MVT::v4i32, 3 }, // The data reported by the IACA tool is "3.5".
2665-
{ ISD::ADD, MVT::v2i16, 3 }, // FIXME: chosen to be less than v4i16
2666-
{ ISD::ADD, MVT::v4i16, 4 }, // FIXME: chosen to be less than v8i16
2667-
{ ISD::ADD, MVT::v8i16, 5 },
2668-
{ ISD::ADD, MVT::v2i8, 2 },
2669-
{ ISD::ADD, MVT::v4i8, 2 },
2670-
{ ISD::ADD, MVT::v8i8, 2 },
2671-
{ ISD::ADD, MVT::v16i8, 3 },
2672-
};
2673-
2674-
static const CostTblEntry AVX1CostTblPairWise[] = {
2675-
{ ISD::FADD, MVT::v4f64, 5 },
2676-
{ ISD::FADD, MVT::v8f32, 7 },
2677-
{ ISD::ADD, MVT::v2i64, 1 }, // The data reported by the IACA tool is "1.5".
2678-
{ ISD::ADD, MVT::v4i64, 5 }, // The data reported by the IACA tool is "4.8".
2679-
{ ISD::ADD, MVT::v8i32, 5 },
2680-
{ ISD::ADD, MVT::v16i16, 6 },
2681-
{ ISD::ADD, MVT::v32i8, 4 },
2682-
};
2683-
26842658
static const CostTblEntry SLMCostTblNoPairWise[] = {
26852659
{ ISD::FADD, MVT::v2f64, 3 },
26862660
{ ISD::ADD, MVT::v2i64, 5 },
@@ -2721,62 +2695,44 @@ int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy,
27212695
EVT VT = TLI->getValueType(DL, ValTy);
27222696
if (VT.isSimple()) {
27232697
MVT MTy = VT.getSimpleVT();
2724-
if (IsPairwise) {
2725-
if (ST->isSLM())
2726-
if (const auto *Entry = CostTableLookup(SLMCostTblPairWise, ISD, MTy))
2727-
return Entry->Cost;
2728-
2729-
if (ST->hasAVX())
2730-
if (const auto *Entry = CostTableLookup(AVX1CostTblPairWise, ISD, MTy))
2731-
return Entry->Cost;
2732-
2733-
if (ST->hasSSE2())
2734-
if (const auto *Entry = CostTableLookup(SSE2CostTblPairWise, ISD, MTy))
2735-
return Entry->Cost;
2736-
} else {
2737-
if (ST->isSLM())
2738-
if (const auto *Entry = CostTableLookup(SLMCostTblNoPairWise, ISD, MTy))
2739-
return Entry->Cost;
2698+
if (ST->isSLM())
2699+
if (const auto *Entry = CostTableLookup(SLMCostTblNoPairWise, ISD, MTy))
2700+
return Entry->Cost;
27402701

2741-
if (ST->hasAVX())
2742-
if (const auto *Entry = CostTableLookup(AVX1CostTblNoPairWise, ISD, MTy))
2743-
return Entry->Cost;
2702+
if (ST->hasAVX())
2703+
if (const auto *Entry = CostTableLookup(AVX1CostTblNoPairWise, ISD, MTy))
2704+
return Entry->Cost;
27442705

2745-
if (ST->hasSSE2())
2746-
if (const auto *Entry = CostTableLookup(SSE2CostTblNoPairWise, ISD, MTy))
2747-
return Entry->Cost;
2748-
}
2706+
if (ST->hasSSE2())
2707+
if (const auto *Entry = CostTableLookup(SSE2CostTblNoPairWise, ISD, MTy))
2708+
return Entry->Cost;
27492709
}
27502710

27512711
std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);
27522712

27532713
MVT MTy = LT.second;
27542714

2755-
if (IsPairwise) {
2756-
if (ST->isSLM())
2757-
if (const auto *Entry = CostTableLookup(SLMCostTblPairWise, ISD, MTy))
2758-
return LT.first * Entry->Cost;
2759-
2760-
if (ST->hasAVX())
2761-
if (const auto *Entry = CostTableLookup(AVX1CostTblPairWise, ISD, MTy))
2762-
return LT.first * Entry->Cost;
2715+
unsigned ArithmeticCost = 0;
2716+
if (LT.first != 1 && MTy.isVector() &&
2717+
MTy.getVectorNumElements() < ValTy->getVectorNumElements()) {
2718+
// Type needs to be split. We need LT.first - 1 arithmetic ops.
2719+
Type *SingleOpTy = VectorType::get(ValTy->getVectorElementType(),
2720+
MTy.getVectorNumElements());
2721+
ArithmeticCost = getArithmeticInstrCost(Opcode, SingleOpTy);
2722+
ArithmeticCost *= LT.first - 1;
2723+
}
27632724

2764-
if (ST->hasSSE2())
2765-
if (const auto *Entry = CostTableLookup(SSE2CostTblPairWise, ISD, MTy))
2766-
return LT.first * Entry->Cost;
2767-
} else {
2768-
if (ST->isSLM())
2769-
if (const auto *Entry = CostTableLookup(SLMCostTblNoPairWise, ISD, MTy))
2770-
return LT.first * Entry->Cost;
2725+
if (ST->isSLM())
2726+
if (const auto *Entry = CostTableLookup(SLMCostTblNoPairWise, ISD, MTy))
2727+
return ArithmeticCost + Entry->Cost;
27712728

2772-
if (ST->hasAVX())
2773-
if (const auto *Entry = CostTableLookup(AVX1CostTblNoPairWise, ISD, MTy))
2774-
return LT.first * Entry->Cost;
2729+
if (ST->hasAVX())
2730+
if (const auto *Entry = CostTableLookup(AVX1CostTblNoPairWise, ISD, MTy))
2731+
return ArithmeticCost + Entry->Cost;
27752732

2776-
if (ST->hasSSE2())
2777-
if (const auto *Entry = CostTableLookup(SSE2CostTblNoPairWise, ISD, MTy))
2778-
return LT.first * Entry->Cost;
2779-
}
2733+
if (ST->hasSSE2())
2734+
if (const auto *Entry = CostTableLookup(SSE2CostTblNoPairWise, ISD, MTy))
2735+
return ArithmeticCost + Entry->Cost;
27802736

27812737
// FIXME: These assume a naive kshift+binop lowering, which is probably
27822738
// conservative in most cases.
@@ -2825,9 +2781,9 @@ int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy,
28252781
};
28262782

28272783
// Handle bool allof/anyof patterns.
2828-
if (!IsPairwise && ValTy->getVectorElementType()->isIntegerTy(1)) {
2784+
if (ValTy->getVectorElementType()->isIntegerTy(1)) {
28292785
unsigned ArithmeticCost = 0;
2830-
if (MTy.isVector() &&
2786+
if (LT.first != 1 && MTy.isVector() &&
28312787
MTy.getVectorNumElements() < ValTy->getVectorNumElements()) {
28322788
// Type needs to be split. We need LT.first - 1 arithmetic ops.
28332789
Type *SingleOpTy = VectorType::get(ValTy->getVectorElementType(),
@@ -2848,9 +2804,77 @@ int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy,
28482804
if (ST->hasSSE2())
28492805
if (const auto *Entry = CostTableLookup(SSE2BoolReduction, ISD, MTy))
28502806
return ArithmeticCost + Entry->Cost;
2807+
2808+
return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwise);
2809+
}
2810+
2811+
unsigned NumVecElts = ValTy->getVectorNumElements();
2812+
unsigned ScalarSize = ValTy->getScalarSizeInBits();
2813+
2814+
// Special case power of 2 reductions where the scalar type isn't changed
2815+
// by type legalization.
2816+
if (!isPowerOf2_32(NumVecElts) || ScalarSize != MTy.getScalarSizeInBits())
2817+
return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwise);
2818+
2819+
unsigned ReductionCost = 0;
2820+
2821+
Type *Ty = ValTy;
2822+
if (LT.first != 1 && MTy.isVector() &&
2823+
MTy.getVectorNumElements() < ValTy->getVectorNumElements()) {
2824+
// Type needs to be split. We need LT.first - 1 arithmetic ops.
2825+
Ty = VectorType::get(ValTy->getVectorElementType(),
2826+
MTy.getVectorNumElements());
2827+
ReductionCost = getArithmeticInstrCost(Opcode, Ty);
2828+
ReductionCost *= LT.first - 1;
2829+
NumVecElts = MTy.getVectorNumElements();
2830+
}
2831+
2832+
// Now handle reduction with the legal type, taking into account size changes
2833+
// at each level.
2834+
while (NumVecElts > 1) {
2835+
// Determine the size of the remaining vector we need to reduce.
2836+
unsigned Size = NumVecElts * ScalarSize;
2837+
NumVecElts /= 2;
2838+
// If we're reducing from 256/512 bits, use an extract_subvector.
2839+
if (Size > 128) {
2840+
Type *SubTy = VectorType::get(ValTy->getVectorElementType(), NumVecElts);
2841+
ReductionCost +=
2842+
getShuffleCost(TTI::SK_ExtractSubvector, Ty, NumVecElts, SubTy);
2843+
Ty = SubTy;
2844+
} else if (Size == 128) {
2845+
// Reducing from 128 bits is a permute of v2f64/v2i64.
2846+
Type *ShufTy;
2847+
if (ValTy->isFloatingPointTy())
2848+
ShufTy = VectorType::get(Type::getDoubleTy(ValTy->getContext()), 2);
2849+
else
2850+
ShufTy = VectorType::get(Type::getInt64Ty(ValTy->getContext()), 2);
2851+
ReductionCost +=
2852+
getShuffleCost(TTI::SK_PermuteSingleSrc, ShufTy, 0, nullptr);
2853+
} else if (Size == 64) {
2854+
// Reducing from 64 bits is a shuffle of v4f32/v4i32.
2855+
Type *ShufTy;
2856+
if (ValTy->isFloatingPointTy())
2857+
ShufTy = VectorType::get(Type::getFloatTy(ValTy->getContext()), 4);
2858+
else
2859+
ShufTy = VectorType::get(Type::getInt32Ty(ValTy->getContext()), 4);
2860+
ReductionCost +=
2861+
getShuffleCost(TTI::SK_PermuteSingleSrc, ShufTy, 0, nullptr);
2862+
} else {
2863+
// Reducing from smaller size is a shift by immediate.
2864+
Type *ShiftTy = VectorType::get(
2865+
Type::getIntNTy(ValTy->getContext(), Size), 128 / Size);
2866+
ReductionCost += getArithmeticInstrCost(
2867+
Instruction::LShr, ShiftTy, TargetTransformInfo::OK_AnyValue,
2868+
TargetTransformInfo::OK_UniformConstantValue,
2869+
TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
2870+
}
2871+
2872+
// Add the arithmetic op for this level.
2873+
ReductionCost += getArithmeticInstrCost(Opcode, Ty);
28512874
}
28522875

2853-
return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwise);
2876+
// Add the final extract element to the cost.
2877+
return ReductionCost + getVectorInstrCost(Instruction::ExtractElement, Ty, 0);
28542878
}
28552879

28562880
int X86TTIImpl::getMinMaxReductionCost(Type *ValTy, Type *CondTy,

0 commit comments

Comments
 (0)