@@ -228,13 +228,13 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
228
228
B.buildConstant (LLT::scalar (64 ), -static_cast <int64_t >(MinOffset)));
229
229
}
230
230
231
- // Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot (x, y))
232
- // Or vecreduce_add(ext(mul(ext(x), ext(y)))) -> vecreduce_add(udot (x, y))
233
- // Or vecreduce_add(ext(x)) -> vecreduce_add(udot (x, 1))
231
+ // Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add([us]dot (x, y))
232
+ // Or vecreduce_add(ext(mul(ext(x), ext(y)))) -> vecreduce_add([us]dot (x, y))
233
+ // Or vecreduce_add(ext(x)) -> vecreduce_add([us]dot (x, 1))
234
234
// Similar to performVecReduceAddCombine in SelectionDAG
235
- bool matchExtAddvToUdotAddv (MachineInstr &MI, MachineRegisterInfo &MRI,
236
- const AArch64Subtarget &STI,
237
- std::tuple<Register, Register, bool > &MatchInfo) {
235
+ bool matchExtAddvToDotAddv (MachineInstr &MI, MachineRegisterInfo &MRI,
236
+ const AArch64Subtarget &STI,
237
+ std::tuple<Register, Register, bool > &MatchInfo) {
238
238
assert (MI.getOpcode () == TargetOpcode::G_VECREDUCE_ADD &&
239
239
" Expected a G_VECREDUCE_ADD instruction" );
240
240
assert (STI.hasDotProd () && " Target should have Dot Product feature" );
@@ -247,8 +247,8 @@ bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
247
247
if (DstTy.getScalarSizeInBits () != 32 || MidTy.getScalarSizeInBits () != 32 )
248
248
return false ;
249
249
250
- // Detect mul(ext, ext) with symetric ext's. If I1Opc is G_ZEXT or G_SEXT then
251
- // the ext's must match the same opcode. It is set to the ext opcode on
250
+ // Detect mul(ext, ext) with symmetric ext's. If I1Opc is G_ZEXT or G_SEXT
251
+ // then the ext's must match the same opcode. It is set to the ext opcode on
252
252
// output.
253
253
auto tryMatchingMulOfExt = [&MRI](MachineInstr *MI, Register &Out1,
254
254
Register &Out2, unsigned &I1Opc) {
@@ -315,11 +315,11 @@ bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
315
315
return true ;
316
316
}
317
317
318
- void applyExtAddvToUdotAddv (MachineInstr &MI, MachineRegisterInfo &MRI,
319
- MachineIRBuilder &Builder,
320
- GISelChangeObserver &Observer,
321
- const AArch64Subtarget &STI,
322
- std::tuple<Register, Register, bool > &MatchInfo) {
318
+ void applyExtAddvToDotAddv (MachineInstr &MI, MachineRegisterInfo &MRI,
319
+ MachineIRBuilder &Builder,
320
+ GISelChangeObserver &Observer,
321
+ const AArch64Subtarget &STI,
322
+ std::tuple<Register, Register, bool > &MatchInfo) {
323
323
assert (MI.getOpcode () == TargetOpcode::G_VECREDUCE_ADD &&
324
324
" Expected a G_VECREDUCE_ADD instruction" );
325
325
assert (STI.hasDotProd () && " Target should have Dot Product feature" );
@@ -581,14 +581,14 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
581
581
}
582
582
583
583
// Pushes ADD/SUB/MUL through extend instructions to decrease the number of
584
- // extend instruction at the end by allowing selection of {s|u}addl sooner i32
585
- // add(i32 ext i8, i32 ext i8) => i32 ext(i16 add(i16 ext i8, i16 ext i8))
584
+ // extend instruction at the end by allowing selection of {s|u}addl sooner
585
+ // i32 add(i32 ext i8, i32 ext i8) => i32 ext(i16 add(i16 ext i8, i16 ext i8))
586
586
bool matchPushAddSubExt (MachineInstr &MI, MachineRegisterInfo &MRI,
587
587
Register DstReg, Register SrcReg1, Register SrcReg2) {
588
588
assert ((MI.getOpcode () == TargetOpcode::G_ADD ||
589
589
MI.getOpcode () == TargetOpcode::G_SUB ||
590
590
MI.getOpcode () == TargetOpcode::G_MUL) &&
591
- " Expected a G_ADD or G_SUB instruction\n " );
591
+ " Expected a G_ADD, G_SUB or G_MUL instruction\n " );
592
592
593
593
// Deal with vector types only
594
594
LLT DstTy = MRI.getType (DstReg);
@@ -623,7 +623,8 @@ void applyPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
623
623
// G_SUB has to sign-extend the result.
624
624
// G_ADD needs to sext from sext and can sext or zext from zext, and G_MUL
625
625
// needs to use the original opcode so the original opcode is used for both.
626
- if (MI.getOpcode () != TargetOpcode::G_SUB)
626
+ if (MI.getOpcode () == TargetOpcode::G_ADD ||
627
+ MI.getOpcode () == TargetOpcode::G_MUL)
627
628
B.buildInstr (Opc, {DstReg}, {AddReg});
628
629
else
629
630
B.buildSExt (DstReg, AddReg);
0 commit comments