@@ -2648,39 +2648,13 @@ int X86TTIImpl::getAddressComputationCost(Type *Ty, ScalarEvolution *SE,
2648
2648
2649
2649
int X86TTIImpl::getArithmeticReductionCost (unsigned Opcode, Type *ValTy,
2650
2650
bool IsPairwise) {
2651
+ // Just use the default implementation for pair reductions.
2652
+ if (IsPairwise)
2653
+ return BaseT::getArithmeticReductionCost (Opcode, ValTy, IsPairwise);
2654
+
2651
2655
// We use the Intel Architecture Code Analyzer(IACA) to measure the throughput
2652
2656
// and make it as the cost.
2653
2657
2654
- static const CostTblEntry SLMCostTblPairWise[] = {
2655
- { ISD::FADD, MVT::v2f64, 3 },
2656
- { ISD::ADD, MVT::v2i64, 5 },
2657
- };
2658
-
2659
- static const CostTblEntry SSE2CostTblPairWise[] = {
2660
- { ISD::FADD, MVT::v2f64, 2 },
2661
- { ISD::FADD, MVT::v4f32, 4 },
2662
- { ISD::ADD, MVT::v2i64, 2 }, // The data reported by the IACA tool is "1.6".
2663
- { ISD::ADD, MVT::v2i32, 2 }, // FIXME: chosen to be less than v4i32.
2664
- { ISD::ADD, MVT::v4i32, 3 }, // The data reported by the IACA tool is "3.5".
2665
- { ISD::ADD, MVT::v2i16, 3 }, // FIXME: chosen to be less than v4i16
2666
- { ISD::ADD, MVT::v4i16, 4 }, // FIXME: chosen to be less than v8i16
2667
- { ISD::ADD, MVT::v8i16, 5 },
2668
- { ISD::ADD, MVT::v2i8, 2 },
2669
- { ISD::ADD, MVT::v4i8, 2 },
2670
- { ISD::ADD, MVT::v8i8, 2 },
2671
- { ISD::ADD, MVT::v16i8, 3 },
2672
- };
2673
-
2674
- static const CostTblEntry AVX1CostTblPairWise[] = {
2675
- { ISD::FADD, MVT::v4f64, 5 },
2676
- { ISD::FADD, MVT::v8f32, 7 },
2677
- { ISD::ADD, MVT::v2i64, 1 }, // The data reported by the IACA tool is "1.5".
2678
- { ISD::ADD, MVT::v4i64, 5 }, // The data reported by the IACA tool is "4.8".
2679
- { ISD::ADD, MVT::v8i32, 5 },
2680
- { ISD::ADD, MVT::v16i16, 6 },
2681
- { ISD::ADD, MVT::v32i8, 4 },
2682
- };
2683
-
2684
2658
static const CostTblEntry SLMCostTblNoPairWise[] = {
2685
2659
{ ISD::FADD, MVT::v2f64, 3 },
2686
2660
{ ISD::ADD, MVT::v2i64, 5 },
@@ -2721,62 +2695,44 @@ int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy,
2721
2695
EVT VT = TLI->getValueType (DL, ValTy);
2722
2696
if (VT.isSimple ()) {
2723
2697
MVT MTy = VT.getSimpleVT ();
2724
- if (IsPairwise) {
2725
- if (ST->isSLM ())
2726
- if (const auto *Entry = CostTableLookup (SLMCostTblPairWise, ISD, MTy))
2727
- return Entry->Cost ;
2728
-
2729
- if (ST->hasAVX ())
2730
- if (const auto *Entry = CostTableLookup (AVX1CostTblPairWise, ISD, MTy))
2731
- return Entry->Cost ;
2732
-
2733
- if (ST->hasSSE2 ())
2734
- if (const auto *Entry = CostTableLookup (SSE2CostTblPairWise, ISD, MTy))
2735
- return Entry->Cost ;
2736
- } else {
2737
- if (ST->isSLM ())
2738
- if (const auto *Entry = CostTableLookup (SLMCostTblNoPairWise, ISD, MTy))
2739
- return Entry->Cost ;
2698
+ if (ST->isSLM ())
2699
+ if (const auto *Entry = CostTableLookup (SLMCostTblNoPairWise, ISD, MTy))
2700
+ return Entry->Cost ;
2740
2701
2741
- if (ST->hasAVX ())
2742
- if (const auto *Entry = CostTableLookup (AVX1CostTblNoPairWise, ISD, MTy))
2743
- return Entry->Cost ;
2702
+ if (ST->hasAVX ())
2703
+ if (const auto *Entry = CostTableLookup (AVX1CostTblNoPairWise, ISD, MTy))
2704
+ return Entry->Cost ;
2744
2705
2745
- if (ST->hasSSE2 ())
2746
- if (const auto *Entry = CostTableLookup (SSE2CostTblNoPairWise, ISD, MTy))
2747
- return Entry->Cost ;
2748
- }
2706
+ if (ST->hasSSE2 ())
2707
+ if (const auto *Entry = CostTableLookup (SSE2CostTblNoPairWise, ISD, MTy))
2708
+ return Entry->Cost ;
2749
2709
}
2750
2710
2751
2711
std::pair<int , MVT> LT = TLI->getTypeLegalizationCost (DL, ValTy);
2752
2712
2753
2713
MVT MTy = LT.second ;
2754
2714
2755
- if (IsPairwise) {
2756
- if (ST->isSLM ())
2757
- if (const auto *Entry = CostTableLookup (SLMCostTblPairWise, ISD, MTy))
2758
- return LT.first * Entry->Cost ;
2759
-
2760
- if (ST->hasAVX ())
2761
- if (const auto *Entry = CostTableLookup (AVX1CostTblPairWise, ISD, MTy))
2762
- return LT.first * Entry->Cost ;
2715
+ unsigned ArithmeticCost = 0 ;
2716
+ if (LT.first != 1 && MTy.isVector () &&
2717
+ MTy.getVectorNumElements () < ValTy->getVectorNumElements ()) {
2718
+ // Type needs to be split. We need LT.first - 1 arithmetic ops.
2719
+ Type *SingleOpTy = VectorType::get (ValTy->getVectorElementType (),
2720
+ MTy.getVectorNumElements ());
2721
+ ArithmeticCost = getArithmeticInstrCost (Opcode, SingleOpTy);
2722
+ ArithmeticCost *= LT.first - 1 ;
2723
+ }
2763
2724
2764
- if (ST->hasSSE2 ())
2765
- if (const auto *Entry = CostTableLookup (SSE2CostTblPairWise, ISD, MTy))
2766
- return LT.first * Entry->Cost ;
2767
- } else {
2768
- if (ST->isSLM ())
2769
- if (const auto *Entry = CostTableLookup (SLMCostTblNoPairWise, ISD, MTy))
2770
- return LT.first * Entry->Cost ;
2725
+ if (ST->isSLM ())
2726
+ if (const auto *Entry = CostTableLookup (SLMCostTblNoPairWise, ISD, MTy))
2727
+ return ArithmeticCost + Entry->Cost ;
2771
2728
2772
- if (ST->hasAVX ())
2773
- if (const auto *Entry = CostTableLookup (AVX1CostTblNoPairWise, ISD, MTy))
2774
- return LT. first * Entry->Cost ;
2729
+ if (ST->hasAVX ())
2730
+ if (const auto *Entry = CostTableLookup (AVX1CostTblNoPairWise, ISD, MTy))
2731
+ return ArithmeticCost + Entry->Cost ;
2775
2732
2776
- if (ST->hasSSE2 ())
2777
- if (const auto *Entry = CostTableLookup (SSE2CostTblNoPairWise, ISD, MTy))
2778
- return LT.first * Entry->Cost ;
2779
- }
2733
+ if (ST->hasSSE2 ())
2734
+ if (const auto *Entry = CostTableLookup (SSE2CostTblNoPairWise, ISD, MTy))
2735
+ return ArithmeticCost + Entry->Cost ;
2780
2736
2781
2737
// FIXME: These assume a naive kshift+binop lowering, which is probably
2782
2738
// conservative in most cases.
@@ -2825,9 +2781,9 @@ int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy,
2825
2781
};
2826
2782
2827
2783
// Handle bool allof/anyof patterns.
2828
- if (!IsPairwise && ValTy->getVectorElementType ()->isIntegerTy (1 )) {
2784
+ if (ValTy->getVectorElementType ()->isIntegerTy (1 )) {
2829
2785
unsigned ArithmeticCost = 0 ;
2830
- if (MTy.isVector () &&
2786
+ if (LT. first != 1 && MTy.isVector () &&
2831
2787
MTy.getVectorNumElements () < ValTy->getVectorNumElements ()) {
2832
2788
// Type needs to be split. We need LT.first - 1 arithmetic ops.
2833
2789
Type *SingleOpTy = VectorType::get (ValTy->getVectorElementType (),
@@ -2848,9 +2804,77 @@ int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy,
2848
2804
if (ST->hasSSE2 ())
2849
2805
if (const auto *Entry = CostTableLookup (SSE2BoolReduction, ISD, MTy))
2850
2806
return ArithmeticCost + Entry->Cost ;
2807
+
2808
+ return BaseT::getArithmeticReductionCost (Opcode, ValTy, IsPairwise);
2809
+ }
2810
+
2811
+ unsigned NumVecElts = ValTy->getVectorNumElements ();
2812
+ unsigned ScalarSize = ValTy->getScalarSizeInBits ();
2813
+
2814
+ // Special case power of 2 reductions where the scalar type isn't changed
2815
+ // by type legalization.
2816
+ if (!isPowerOf2_32 (NumVecElts) || ScalarSize != MTy.getScalarSizeInBits ())
2817
+ return BaseT::getArithmeticReductionCost (Opcode, ValTy, IsPairwise);
2818
+
2819
+ unsigned ReductionCost = 0 ;
2820
+
2821
+ Type *Ty = ValTy;
2822
+ if (LT.first != 1 && MTy.isVector () &&
2823
+ MTy.getVectorNumElements () < ValTy->getVectorNumElements ()) {
2824
+ // Type needs to be split. We need LT.first - 1 arithmetic ops.
2825
+ Ty = VectorType::get (ValTy->getVectorElementType (),
2826
+ MTy.getVectorNumElements ());
2827
+ ReductionCost = getArithmeticInstrCost (Opcode, Ty);
2828
+ ReductionCost *= LT.first - 1 ;
2829
+ NumVecElts = MTy.getVectorNumElements ();
2830
+ }
2831
+
2832
+ // Now handle reduction with the legal type, taking into account size changes
2833
+ // at each level.
2834
+ while (NumVecElts > 1 ) {
2835
+ // Determine the size of the remaining vector we need to reduce.
2836
+ unsigned Size = NumVecElts * ScalarSize;
2837
+ NumVecElts /= 2 ;
2838
+ // If we're reducing from 256/512 bits, use an extract_subvector.
2839
+ if (Size > 128 ) {
2840
+ Type *SubTy = VectorType::get (ValTy->getVectorElementType (), NumVecElts);
2841
+ ReductionCost +=
2842
+ getShuffleCost (TTI::SK_ExtractSubvector, Ty, NumVecElts, SubTy);
2843
+ Ty = SubTy;
2844
+ } else if (Size == 128 ) {
2845
+ // Reducing from 128 bits is a permute of v2f64/v2i64.
2846
+ Type *ShufTy;
2847
+ if (ValTy->isFloatingPointTy ())
2848
+ ShufTy = VectorType::get (Type::getDoubleTy (ValTy->getContext ()), 2 );
2849
+ else
2850
+ ShufTy = VectorType::get (Type::getInt64Ty (ValTy->getContext ()), 2 );
2851
+ ReductionCost +=
2852
+ getShuffleCost (TTI::SK_PermuteSingleSrc, ShufTy, 0 , nullptr );
2853
+ } else if (Size == 64 ) {
2854
+ // Reducing from 64 bits is a shuffle of v4f32/v4i32.
2855
+ Type *ShufTy;
2856
+ if (ValTy->isFloatingPointTy ())
2857
+ ShufTy = VectorType::get (Type::getFloatTy (ValTy->getContext ()), 4 );
2858
+ else
2859
+ ShufTy = VectorType::get (Type::getInt32Ty (ValTy->getContext ()), 4 );
2860
+ ReductionCost +=
2861
+ getShuffleCost (TTI::SK_PermuteSingleSrc, ShufTy, 0 , nullptr );
2862
+ } else {
2863
+ // Reducing from smaller size is a shift by immediate.
2864
+ Type *ShiftTy = VectorType::get (
2865
+ Type::getIntNTy (ValTy->getContext (), Size), 128 / Size);
2866
+ ReductionCost += getArithmeticInstrCost (
2867
+ Instruction::LShr, ShiftTy, TargetTransformInfo::OK_AnyValue,
2868
+ TargetTransformInfo::OK_UniformConstantValue,
2869
+ TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
2870
+ }
2871
+
2872
+ // Add the arithmetic op for this level.
2873
+ ReductionCost += getArithmeticInstrCost (Opcode, Ty);
2851
2874
}
2852
2875
2853
- return BaseT::getArithmeticReductionCost (Opcode, ValTy, IsPairwise);
2876
+ // Add the final extract element to the cost.
2877
+ return ReductionCost + getVectorInstrCost (Instruction::ExtractElement, Ty, 0 );
2854
2878
}
2855
2879
2856
2880
int X86TTIImpl::getMinMaxReductionCost (Type *ValTy, Type *CondTy,
0 commit comments