@@ -37809,21 +37809,14 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
37809
37809
if (!Subtarget.hasSSE2())
37810
37810
return SDValue();
37811
37811
37812
- // Verify the type we're extracting from is any integer type above i16.
37813
- EVT VT = Extract->getOperand(0).getValueType();
37814
- if (!VT.isSimple() || !(VT.getVectorElementType().getSizeInBits() > 16))
37812
+ EVT ExtractVT = Extract->getValueType(0);
37813
+ // Verify the type we're extracting is either i32 or i64.
37814
+ // FIXME: Could support other types, but this is what we have coverage for.
37815
+ if (ExtractVT != MVT::i32 && ExtractVT != MVT::i64)
37815
37816
return SDValue();
37816
37817
37817
- unsigned RegSize = 128;
37818
- if (Subtarget.useBWIRegs())
37819
- RegSize = 512;
37820
- else if (Subtarget.hasAVX())
37821
- RegSize = 256;
37822
-
37823
- // We handle upto v16i* for SSE2 / v32i* for AVX / v64i* for AVX512.
37824
- // TODO: We should be able to handle larger vectors by splitting them before
37825
- // feeding them into several SADs, and then reducing over those.
37826
- if (RegSize / VT.getVectorNumElements() < 8)
37818
+ EVT VT = Extract->getOperand(0).getValueType();
37819
+ if (!isPowerOf2_32(VT.getVectorNumElements()))
37827
37820
return SDValue();
37828
37821
37829
37822
// Match shuffle + add pyramid.
@@ -37839,8 +37832,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
37839
37832
// (extends the sign bit which is zero).
37840
37833
// So it is correct to skip the sign/zero extend instruction.
37841
37834
if (Root && (Root.getOpcode() == ISD::SIGN_EXTEND ||
37842
- Root.getOpcode() == ISD::ZERO_EXTEND ||
37843
- Root.getOpcode() == ISD::ANY_EXTEND))
37835
+ Root.getOpcode() == ISD::ZERO_EXTEND ||
37836
+ Root.getOpcode() == ISD::ANY_EXTEND))
37844
37837
Root = Root.getOperand(0);
37845
37838
37846
37839
// If there was a match, we want Root to be a select that is the root of an
@@ -37860,7 +37853,7 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
37860
37853
// If the original vector was wider than 8 elements, sum over the results
37861
37854
// in the SAD vector.
37862
37855
unsigned Stages = Log2_32(VT.getVectorNumElements());
37863
- MVT SadVT = SAD.getSimpleValueType ();
37856
+ EVT SadVT = SAD.getValueType ();
37864
37857
if (Stages > 3) {
37865
37858
unsigned SadElems = SadVT.getVectorNumElements();
37866
37859
@@ -37875,12 +37868,12 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
37875
37868
}
37876
37869
}
37877
37870
37878
- MVT Type = Extract->getSimpleValueType(0 );
37879
- unsigned TypeSizeInBits = Type.getSizeInBits();
37880
- // Return the lowest TypeSizeInBits bits.
37881
- MVT ResVT = MVT::getVectorVT(Type, SadVT.getSizeInBits() / TypeSizeInBits );
37871
+ unsigned ExtractSizeInBits = ExtractVT.getSizeInBits( );
37872
+ // Return the lowest ExtractSizeInBits bits.
37873
+ EVT ResVT = EVT::getVectorVT(*DAG.getContext(), ExtractVT,
37874
+ SadVT.getSizeInBits() / ExtractSizeInBits );
37882
37875
SAD = DAG.getBitcast(ResVT, SAD);
37883
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Type , SAD,
37876
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT , SAD,
37884
37877
Extract->getOperand(1));
37885
37878
}
37886
37879
0 commit comments