
Commit a4dd51d

[mlir][ArithToAMDGPU] Use native packing support (#150342)
The current arith-to-amdgpu patterns for scaling_extf and scaling_truncf don't take full advantage of the native packing ability of the intrinsics being targeted. Scaling extension takes the ___location of the two elements to be extended as a constant argument (a byte for fp4, a half for fp8), and scaling truncation takes a 32-bit input register and the byte or half to write the truncated values to. Not using these features caused excess register pressure. This PR resolves that inefficiency. It also adds a test for the expected use case of extending or truncating a block of 32 values to/from fp4 with a uniform scale, ensuring that this usage requires a minimal amount of vector shuffling.
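As a minimal sketch of the extension side (adapted from the updated scaling-extf.mlir checks in this commit; the value names %v, %scale, %lo, and %hi are hypothetical), the index operand of amdgpu.scaled_ext_packed selects which packed pair of the source register is extended, so a four-element fp8 vector can be extended with two ops and no intermediate vector.extract_strided_slice:

  // Assumes %v : vector<4xf8E5M2> and %scale : f32; names are illustrative.
  // The trailing [0]/[1] selects which half (pair of fp8 bytes) of %v to extend.
  %lo = amdgpu.scaled_ext_packed %v[0], %scale : vector<4xf8E5M2> to vector<2xf32>
  %hi = amdgpu.scaled_ext_packed %v[1], %scale : vector<4xf8E5M2> to vector<2xf32>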
1 parent 33c9445 commit a4dd51d

3 files changed: +138 -93 lines

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 63 additions & 38 deletions
@@ -449,7 +449,7 @@ LogicalResult
 ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opOutWidth = 2;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -460,6 +460,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   Type scaleType = getElementTypeOrSelf(scale);
   Type outType = getElementTypeOrSelf(out);
 
+  int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth();
+
   VectorType outVecType = dyn_cast<VectorType>(out.getType());
   VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
 
@@ -473,7 +475,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   else if (scaleType.getIntOrFloatBitWidth() > 32)
     scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
 
-  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+  VectorType extScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
     Value inCast = vector::BroadcastOp::create(rewriter, loc,
@@ -487,10 +489,11 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
   SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+  if (origScaleVecType)
     llvm::append_range(originalScaleShape, origScaleVecType.getShape());
 
   originalScaleShape.insert(originalScaleShape.end(),
@@ -524,19 +527,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+    for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
          i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = vector::ExtractStridedSliceOp::create(
-          rewriter, loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
-      Value scaleExt = amdgpu::ScaledExtPackedOp::create(
-          rewriter, loc, extScaleResultType, slice, uniformScale, 0);
-      if (sliceWidth != opWidth)
-        scaleExt = vector::ExtractStridedSliceOp::create(
-            rewriter, loc, scaleExt, 0, sliceWidth, 1);
-      blockResult = vector::InsertStridedSliceOp::create(
-          rewriter, loc, scaleExt, blockResult, i, 1);
+         i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
+      Value inSlice = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, block1D, i, inSliceWidth, 1);
+      for (int64_t j = 0,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
+           j < inSliceWidth; j += outSliceWidth,
+           outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
+        // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+        Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+            rewriter, loc, extScaleResultType, inSlice, uniformScale,
+            j / opOutWidth);
+        if (outSliceWidth < opOutWidth) {
+          scaleExt = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, scaleExt, 0, outSliceWidth, 1);
+        }
+        blockResult = vector::InsertStridedSliceOp::create(
+            rewriter, loc, scaleExt, blockResult, i + j, 1);
+      }
     }
 
     VectorType resultType = VectorType::get(ratio, outType);
@@ -555,7 +565,7 @@ LogicalResult
 ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
                                              PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opInWidth = 2;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -568,7 +578,6 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   VectorType outVecType = dyn_cast<VectorType>(out.getType());
   VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
-
   if (outVecType && outVecType.isScalable())
     return failure();
 
@@ -581,8 +590,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                          rewriter.getFloatAttr(outType, 0.0));
-  unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
-  VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+  int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
+  VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
     Type inVecType = VectorType::get(1, inType);
@@ -598,16 +607,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
-  SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
-    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+  SmallVector<int64_t> scaleShape;
+  if (origScaleVecType)
+    llvm::append_range(scaleShape, origScaleVecType.getShape());
 
-  originalScaleShape.insert(originalScaleShape.end(),
-                            inShape.size() - originalScaleShape.size(), 1);
+  scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
 
-  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  auto maybeRatio = computeShapeRatio(inShape, scaleShape);
   assert(maybeRatio &&
          "failed to derive block size from broadcast or splat operation");
 
@@ -633,20 +642,36 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
-         i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = vector::ExtractStridedSliceOp::create(
-          rewriter, loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
-      Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
-          rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
-          /*existing=*/nullptr);
-      int64_t packedWidth =
-          cast<VectorType>(scaleTrunc.getType()).getNumElements();
-      if (packedWidth != opWidth)
+    for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
+         i < blockSize; i += outSliceWidth,
+         outSliceWidth = std::min(opOutWidth, blockSize - i)) {
+      Value scaleTrunc;
+      // Case where <= 2 elements are being truncated.
+      if (outSliceWidth <= opInWidth) {
+        Value slice = vector::ExtractStridedSliceOp::create(
+            rewriter, loc, block1D, i, outSliceWidth, 1);
+        // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+        scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+            rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
+            /*existing=*/nullptr);
+      } else {
+        scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
+                                                 truncScaleResultType, zero);
+        for (int64_t j = 0,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j);
+             j < outSliceWidth; j += opInWidth,
+             inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
+          Value slice = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, block1D, i + j, inSliceWidth, 1);
+          scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+              rewriter, loc, truncScaleResultType, slice, uniformScale,
+              j / opInWidth, scaleTrunc);
+        }
+      }
+      if (outSliceWidth != opOutWidth) {
         scaleTrunc = vector::ExtractStridedSliceOp::create(
-            rewriter, loc, scaleTrunc, 0, sliceWidth, 1);
+            rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
+      }
       blockResult = vector::InsertStridedSliceOp::create(
           rewriter, loc, scaleTrunc, blockResult, i, 1);
     }

mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir

Lines changed: 36 additions & 22 deletions
@@ -163,27 +163,23 @@ func.func @conversion_f4_f16_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<
 // CHECK-DAG: %[[SCALE_CAST:.+]] = vector.shape_cast %[[BCAST]]
 // CHECK-DAG: %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_CAST]]
 // CHECK-DAG: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
-// CHECK-NEXT: vector.shape_cast
+// CHECK-NEXT: %[[IN_SLICE_CAST:.+]] = vector.shape_cast
 // CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 0, 0]
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.scaled_ext_packed
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.scaled_ext_packed
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
+// CHECK-NEXT: %[[LOWHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][0]
+// CHECK-NEXT: vector.insert_strided_slice %[[LOWHALF]], %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT: %[[HIGHHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][1]
+// CHECK-NEXT: vector.insert_strided_slice %[[HIGHHALF]], %{{.+}} {offsets = [2], strides = [1]}
 // CHECK-NEXT: vector.shape_cast
 // CHECK-NEXT: vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
 // CHECK-NEXT: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
 // CHECK-NEXT: vector.shape_cast
 // CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 1, 0]
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
 // CHECK-NEXT: amdgpu.scaled_ext_packed
 // CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
 // CHECK-NEXT: amdgpu.scaled_ext_packed
 // CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
 // CHECK-NEXT: vector.shape_cast
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
+// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
 func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf32> {
   %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
   %cast1 = vector.shape_cast %in : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
@@ -203,21 +199,17 @@ func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8
 // CHECK-NEXT: %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_FLAT]] : vector<6xf8E8M0FNU> to vector<6xf32>
 // CHECK-NEXT: %[[IN_SLICE_0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
 // CHECK-NEXT: %[[SCALE_SCALAR_0:.+]] = vector.extract %[[SCALE_EXT]][0] : f32 from vector<6xf32>
-// CHECK-NEXT: %[[IN_CHUNK_0A:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0A]][0], %[[SCALE_SCALAR_0]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][0], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT: %[[PARTIAL_ACC_0:.+]] = vector.insert_strided_slice %[[PACKED_0A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT: %[[IN_CHUNK_0B:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT: %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0B]][0], %[[SCALE_SCALAR_0]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][1], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT: %[[PACKED_0B:.+]] = vector.extract_strided_slice %[[PACKED_0B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
 // CHECK-NEXT: %[[OUT_SLICE_0:.+]] = vector.insert_strided_slice %[[PACKED_0B]], %[[PARTIAL_ACC_0]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
 // CHECK-NEXT: %[[FINAL_ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SLICE_0]], %[[CST_FINAL]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32>
 // CHECK-NEXT: %[[IN_SLICE_1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
 // CHECK-NEXT: %[[SCALE_SCALAR_1:.+]] = vector.extract %[[SCALE_EXT]][3] : f32 from vector<6xf32>
-// CHECK-NEXT: %[[IN_CHUNK_1A:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1A]][0], %[[SCALE_SCALAR_1]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][0], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT: %[[PARTIAL_ACC_1:.+]] = vector.insert_strided_slice %[[PACKED_1A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT: %[[IN_CHUNK_1B:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT: %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1B]][0], %[[SCALE_SCALAR_1]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][1], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT: %[[PACKED_1B:.+]] = vector.extract_strided_slice %[[PACKED_1B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
 // CHECK-NEXT: %[[OUT_SLICE_1:.+]] = vector.insert_strided_slice %[[PACKED_1B]], %[[PARTIAL_ACC_1]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
 // CHECK-NEXT: %[[RESULT:.+]] = vector.insert_strided_slice %[[OUT_SLICE_1]], %[[FINAL_ACC_A]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32>
@@ -236,11 +228,9 @@ func.func @conversion_broadcast_odd(%in: vector<6xf8E5M2>, %scale: vector<2xf8E8
 // CHECK-DAG: %[[SCALE_SPLAT:.+]] = vector.broadcast %arg1 : f8E8M0FNU to vector<4xf8E8M0FNU>
 // CHECK-DAG: %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
 // CHECK-DAG: %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
-// CHECK: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK0]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK: %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %arg0[0], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
 // CHECK-NEXT: %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
-// CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK1]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %arg0[1], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
 // CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
 // CHECK-NEXT: return %[[FINAL_RESULT]] : vector<4xf32>
 func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vector<4xf32> {
@@ -261,3 +251,27 @@ func.func @conversion_scalar(%in: f8E5M2, %scale: f8E8M0FNU) -> f32 {
   %ext = arith.scaling_extf %in, %scale : f8E5M2, f8E8M0FNU to f32
   return %ext : f32
 }
+
+// -----
+
+// CHECK-LABEL: @long_fp4_broadcast
+// CHECK-COUNT-4: amdgpu.scaled_ext_packed %{{.+}}[3]
+// CHECK-NOT: amdgpu.scaled_ext_packed
+// CHECK: return
+func.func @long_fp4_broadcast(%in: vector<32xf4E2M1FN>, %scale: f32) -> vector<32xf32> {
+  %splat = vector.broadcast %scale : f32 to vector<32xf32>
+  %ext = arith.scaling_extf %in, %splat : vector<32xf4E2M1FN>, vector<32xf32> to vector<32xf32>
+  return %ext : vector<32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @long_fp8_broadcast
+// CHECK-COUNT-8: amdgpu.scaled_ext_packed %{{.+}}[1]
+// CHECK-NOT: amdgpu.scaled_ext_packed
+// CHECK: return
+func.func @long_fp8_broadcast(%in: vector<32xf8E4M3FN>, %scale: f32) -> vector<32xf32> {
+  %splat = vector.broadcast %scale : f32 to vector<32xf32>
+  %ext = arith.scaling_extf %in, %splat : vector<32xf8E4M3FN>, vector<32xf32> to vector<32xf32>
+  return %ext : vector<32xf32>
+}
