@@ -449,7 +449,7 @@ LogicalResult
449
449
ScalingExtFRewritePattern::matchAndRewrite (arith::ScalingExtFOp op,
450
450
PatternRewriter &rewriter) const {
451
451
Location loc = op.getLoc ();
452
- constexpr int64_t opWidth = 2 ;
452
+ constexpr int64_t opOutWidth = 2 ;
453
453
454
454
Value in = op.getIn ();
455
455
Value scale = op.getScale ();
@@ -460,6 +460,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
460
460
Type scaleType = getElementTypeOrSelf (scale);
461
461
Type outType = getElementTypeOrSelf (out);
462
462
463
+ int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth ();
464
+
463
465
VectorType outVecType = dyn_cast<VectorType>(out.getType ());
464
466
VectorType scaleVecType = dyn_cast<VectorType>(scale.getType ());
465
467
@@ -473,7 +475,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
473
475
else if (scaleType.getIntOrFloatBitWidth () > 32 )
474
476
scale = arith::TruncFOp::create (rewriter, loc, scaleF32Type, scale);
475
477
476
- VectorType extScaleResultType = VectorType::get (opWidth , outType);
478
+ VectorType extScaleResultType = VectorType::get (opOutWidth , outType);
477
479
478
480
if (!outVecType) {
479
481
Value inCast = vector::BroadcastOp::create (rewriter, loc,
@@ -487,10 +489,11 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
487
489
488
490
VectorType inVecType = cast<VectorType>(in.getType ());
489
491
Value origScale = getOriginalVectorValue (op.getScale ());
492
+ VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType ());
490
493
491
494
ArrayRef<int64_t > inShape = inVecType.getShape ();
492
495
SmallVector<int64_t > originalScaleShape;
493
- if (auto origScaleVecType = dyn_cast<VectorType>(origScale. getType ()) )
496
+ if (origScaleVecType)
494
497
llvm::append_range (originalScaleShape, origScaleVecType.getShape ());
495
498
496
499
originalScaleShape.insert (originalScaleShape.end (),
@@ -524,19 +527,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
524
527
Value blockResult =
525
528
rewriter.createOrFold <vector::BroadcastOp>(loc, blockResultType, zero);
526
529
527
- for (int64_t i = 0 , sliceWidth = std::min (opWidth , blockSize - i);
530
+ for (int64_t i = 0 , inSliceWidth = std::min (opInWidth , blockSize - i);
528
531
i < blockSize;
529
- i += sliceWidth, sliceWidth = std::min (opWidth, blockSize - i)) {
530
- Value slice = vector::ExtractStridedSliceOp::create (
531
- rewriter, loc, block1D, i, sliceWidth, 1 );
532
- // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
533
- Value scaleExt = amdgpu::ScaledExtPackedOp::create (
534
- rewriter, loc, extScaleResultType, slice, uniformScale, 0 );
535
- if (sliceWidth != opWidth)
536
- scaleExt = vector::ExtractStridedSliceOp::create (
537
- rewriter, loc, scaleExt, 0 , sliceWidth, 1 );
538
- blockResult = vector::InsertStridedSliceOp::create (
539
- rewriter, loc, scaleExt, blockResult, i, 1 );
532
+ i += inSliceWidth, inSliceWidth = std::min (opInWidth, blockSize - i)) {
533
+ Value inSlice = vector::ExtractStridedSliceOp::create (
534
+ rewriter, loc, block1D, i, inSliceWidth, 1 );
535
+ for (int64_t j = 0 ,
536
+ outSliceWidth = std::min (opOutWidth, inSliceWidth - j);
537
+ j < inSliceWidth; j += outSliceWidth,
538
+ outSliceWidth = std::min (opOutWidth, inSliceWidth - j)) {
539
+ // TODO: replace this with non-packed ScaledExtOp for outSliceWidth == 1
540
+ Value scaleExt = amdgpu::ScaledExtPackedOp::create (
541
+ rewriter, loc, extScaleResultType, inSlice, uniformScale,
542
+ j / opOutWidth);
543
+ if (outSliceWidth < opOutWidth) {
544
+ scaleExt = vector::ExtractStridedSliceOp::create (
545
+ rewriter, loc, scaleExt, 0 , outSliceWidth, 1 );
546
+ }
547
+ blockResult = vector::InsertStridedSliceOp::create (
548
+ rewriter, loc, scaleExt, blockResult, i + j, 1 );
549
+ }
540
550
}
541
551
542
552
VectorType resultType = VectorType::get (ratio, outType);
@@ -555,7 +565,7 @@ LogicalResult
555
565
ScalingTruncFRewritePattern::matchAndRewrite (arith::ScalingTruncFOp op,
556
566
PatternRewriter &rewriter) const {
557
567
Location loc = op.getLoc ();
558
- constexpr int64_t opWidth = 2 ;
568
+ constexpr int64_t opInWidth = 2 ;
559
569
560
570
Value in = op.getIn ();
561
571
Value scale = op.getScale ();
@@ -568,7 +578,6 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
568
578
569
579
VectorType outVecType = dyn_cast<VectorType>(out.getType ());
570
580
VectorType scaleVecType = dyn_cast<VectorType>(scale.getType ());
571
-
572
581
if (outVecType && outVecType.isScalable ())
573
582
return failure ();
574
583
@@ -581,8 +590,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
581
590
582
591
Value zero = arith::ConstantOp::create (rewriter, loc, outType,
583
592
rewriter.getFloatAttr (outType, 0.0 ));
584
- unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth ();
585
- VectorType truncScaleResultType = VectorType::get (numPackedElem , outType);
593
+ int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth ();
594
+ VectorType truncScaleResultType = VectorType::get (opOutWidth , outType);
586
595
587
596
if (!outVecType) {
588
597
Type inVecType = VectorType::get (1 , inType);
@@ -598,16 +607,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
598
607
599
608
VectorType inVecType = cast<VectorType>(in.getType ());
600
609
Value origScale = getOriginalVectorValue (op.getScale ());
610
+ VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType ());
601
611
602
612
ArrayRef<int64_t > inShape = inVecType.getShape ();
603
- SmallVector<int64_t > originalScaleShape ;
604
- if (auto origScaleVecType = dyn_cast<VectorType>(origScale. getType ()) )
605
- llvm::append_range (originalScaleShape , origScaleVecType.getShape ());
613
+ SmallVector<int64_t > scaleShape ;
614
+ if (origScaleVecType)
615
+ llvm::append_range (scaleShape , origScaleVecType.getShape ());
606
616
607
- originalScaleShape.insert (originalScaleShape.end (),
608
- inShape.size () - originalScaleShape.size (), 1 );
617
+ scaleShape.insert (scaleShape.end (), inShape.size () - scaleShape.size (), 1 );
609
618
610
- auto maybeRatio = computeShapeRatio (inShape, originalScaleShape );
619
+ auto maybeRatio = computeShapeRatio (inShape, scaleShape );
611
620
assert (maybeRatio &&
612
621
" failed to derive block size from broadcast or splat operation" );
613
622
@@ -633,20 +642,36 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
633
642
Value blockResult =
634
643
rewriter.createOrFold <vector::BroadcastOp>(loc, blockResultType, zero);
635
644
636
- for (int64_t i = 0 , sliceWidth = std::min (opWidth, blockSize - i);
637
- i < blockSize;
638
- i += sliceWidth, sliceWidth = std::min (opWidth, blockSize - i)) {
639
- Value slice = vector::ExtractStridedSliceOp::create (
640
- rewriter, loc, block1D, i, sliceWidth, 1 );
641
- // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
642
- Value scaleTrunc = amdgpu::PackedScaledTruncOp::create (
643
- rewriter, loc, truncScaleResultType, slice, uniformScale, 0 ,
644
- /* existing=*/ nullptr );
645
- int64_t packedWidth =
646
- cast<VectorType>(scaleTrunc.getType ()).getNumElements ();
647
- if (packedWidth != opWidth)
645
+ for (int64_t i = 0 , outSliceWidth = std::min (opOutWidth, blockSize - i);
646
+ i < blockSize; i += outSliceWidth,
647
+ outSliceWidth = std::min (opOutWidth, blockSize - i)) {
648
+ Value scaleTrunc;
649
+ // Case where at most opInWidth (2) elements are being truncated, so a single packed truncation covers the slice.
650
+ if (outSliceWidth <= opInWidth) {
651
+ Value slice = vector::ExtractStridedSliceOp::create (
652
+ rewriter, loc, block1D, i, outSliceWidth, 1 );
653
+ // TODO: replace this with non-packed ScaledTruncOp for outSliceWidth == 1
654
+ scaleTrunc = amdgpu::PackedScaledTruncOp::create (
655
+ rewriter, loc, truncScaleResultType, slice, uniformScale, 0 ,
656
+ /* existing=*/ nullptr );
657
+ } else {
658
+ scaleTrunc = vector::BroadcastOp::create (rewriter, loc,
659
+ truncScaleResultType, zero);
660
+ for (int64_t j = 0 ,
661
+ inSliceWidth = std::min (opInWidth, outSliceWidth - j);
662
+ j < outSliceWidth; j += opInWidth,
663
+ inSliceWidth = std::min (opInWidth, outSliceWidth - j)) {
664
+ Value slice = vector::ExtractStridedSliceOp::create (
665
+ rewriter, loc, block1D, i + j, inSliceWidth, 1 );
666
+ scaleTrunc = amdgpu::PackedScaledTruncOp::create (
667
+ rewriter, loc, truncScaleResultType, slice, uniformScale,
668
+ j / opInWidth, scaleTrunc);
669
+ }
670
+ }
671
+ if (outSliceWidth != opOutWidth) {
648
672
scaleTrunc = vector::ExtractStridedSliceOp::create (
649
- rewriter, loc, scaleTrunc, 0 , sliceWidth, 1 );
673
+ rewriter, loc, scaleTrunc, 0 , outSliceWidth, 1 );
674
+ }
650
675
blockResult = vector::InsertStridedSliceOp::create (
651
676
rewriter, loc, scaleTrunc, blockResult, i, 1 );
652
677
}
0 commit comments