Skip to content

Commit e6f360b

Browse files
authored
[MLIR][XeGPU] Allow load/store/prefetch to use [memref+offset] instead of tdesc (#150576)
Add variants of load/store/prefetch that accept an offset. The new xegpu.load variant accepts memref+offset, and the existing tdesc operand will be removed in a future PR. The semantics are the combination of "creating scattered_tdesc + xegpu.load with scattered_tdesc". The current xegpu.load accepts a tdesc operand, which encapsulates "memref+offset". This PR folds "memref+offset" directly into xegpu.load, replacing "tdesc". Create_tdesc will be removed, because scatter_tdesc contains only the base address once the offsets are taken away, so there is no point in keeping it.
```mlir
// wi level code example
%2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<2xf32>
xegpu.store %val, %src[%offsets], %mask: vector<1xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
xegpu.prefetch %src[%0] : ui64, vector<1xindex>
```
1 parent b9a627e commit e6f360b

File tree

6 files changed

+338
-26
lines changed

6 files changed

+338
-26
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 136 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -628,35 +628,71 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
628628
As compared to prefetch_nd, which works on non-scattered TensorDesc,
629629
it works on scattered TensorDesc instead.
630630

631-
Example:
631+
Example 1:
632632
```mlir
633633
xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
634634
l2_hint = #xegpu.cache_hint<cached>,
635635
l3_hint = #xegpu.cache_hint<cached>}
636636
: !xegpu.tensor_desc<16xf16>
637637
```
638+
639+
Example 2:
640+
A variant accepts a memref as the base pointer and an offset instead of a scattered TensorDesc.
641+
It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc".
642+
The source operand could be a raw pointer (uint64_t).
643+
Please refer to create_tdesc for the restriction of memref.
644+
```mlir
645+
%a = memref.alloc() : memref<1024xf32>
646+
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
647+
xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>,
648+
l2_hint = #xegpu.cache_hint<cached>,
649+
l3_hint = #xegpu.cache_hint<cached>}
650+
: memref<1024xf32>, vector<4xindex>
651+
```
638652

639653
}];
640654

641-
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
655+
let arguments = (ins XeGPU_GatherScatterSourceType: $source,
656+
Optional<XeGPU_OffsetType>: $offsets,
642657
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
643658
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
644659
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
645660

646661
let extraClassDeclaration = extraBaseClassDeclaration # [{
662+
Type getSourceType() {
663+
return getSource().getType();
664+
}
665+
666+
TypedValue<xegpu::TensorDescType> getTensorDesc() {
667+
if (auto tdescType = getTensorDescType()) {
668+
return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
669+
}
670+
return TypedValue<xegpu::TensorDescType>();
671+
}
672+
647673
xegpu::TensorDescType getTensorDescType() {
648-
return getTensorDesc().getType();
674+
return dyn_cast<xegpu::TensorDescType>(getSourceType());
649675
}
650676
}];
651677

652-
let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
678+
let assemblyFormat = [{
679+
$source
680+
(`[` $offsets^ `]`)?
681+
prop-dict
682+
attr-dict `:` type(operands)
683+
}];
684+
685+
let builders = [
686+
OpBuilder<(ins "Value": $source,
687+
"xegpu::CachePolicyAttr": $l1_hint,
688+
"xegpu::CachePolicyAttr": $l2_hint,
689+
"xegpu::CachePolicyAttr": $l3_hint)>
690+
];
653691

654692
let hasVerifier = 1;
655693
}
656694

657-
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
658-
AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
659-
]> {
695+
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
660696
let summary = "load a set of scattered data points from memory.";
661697

662698
let description = [{ It (aka. load) loads data for each work-item. The output
@@ -687,6 +723,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
687723
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
688724
vector<16xi1> -> vector<16x8xf32>
689725
```
726+
690727
Example 3 (SIMT mode):
691728
```mlir
692729
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -695,19 +732,48 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
695732
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
696733
vector<16xi1> -> vector<8xf32>
697734
```
735+
736+
Example 4:
737+
A variant accepts a memref as the base pointer and an offset instead of a scattered TensorDesc.
738+
It combines "create scattered TensorDesc" and "load with scattered TensorDesc".
739+
The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
740+
for the restriction of memref.
741+
```mlir
742+
%a = memref.alloc() : memref<1024xf32>
743+
%offsets = vector.step : vector<16xindex>
744+
%mask = vector.constant_mask [16]: vector<16xi1>
745+
%val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
746+
l2_hint = #xegpu.cache_hint<cached>,
747+
l3_hint = #xegpu.cache_hint<cached>}
748+
: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
749+
```
698750

699751
}];
700752

701-
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
753+
let arguments = (ins XeGPU_GatherScatterSourceType: $source,
754+
Optional<XeGPU_OffsetType>: $offsets,
702755
XeGPU_MaskType: $mask,
756+
OptionalAttr<I64Attr>: $chunk_size,
703757
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
704758
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
705759
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
706760
let results = (outs XeGPU_ValueType: $value);
707761

708762
let extraClassDeclaration = extraBaseClassDeclaration # [{
763+
764+
Type getSourceType() {
765+
return getSource().getType();
766+
}
767+
768+
TypedValue<xegpu::TensorDescType> getTensorDesc() {
769+
if (auto tdescType = getTensorDescType()) {
770+
return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
771+
}
772+
return TypedValue<xegpu::TensorDescType>();
773+
}
774+
709775
xegpu::TensorDescType getTensorDescType() {
710-
return getTensorDesc().getType();
776+
return dyn_cast<xegpu::TensorDescType>(getSourceType());
711777
}
712778

713779
mlir::Type getElementType() {
@@ -725,15 +791,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
725791

726792
}];
727793

728-
let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
729-
`:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
794+
let assemblyFormat = [{
795+
$source
796+
(`[` $offsets^ `]`)? `,`
797+
$mask prop-dict
798+
attr-dict `:` type(operands) `->` type($value)
799+
}];
800+
801+
let builders = [
802+
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
803+
"xegpu::CachePolicyAttr": $l1_hint,
804+
"xegpu::CachePolicyAttr": $l2_hint,
805+
"xegpu::CachePolicyAttr": $l3_hint)>
806+
];
730807

731808
let hasVerifier = 1;
732809
}
733810

734-
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
735-
AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>
736-
]> {
811+
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
737812
let summary = "store data to scattered memory locations.";
738813
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
739814
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
@@ -768,19 +843,49 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
768843
l3_hint = #xegpu.cache_hint<write_through>}>
769844
: vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>>, vector<16xi1>
770845
```
846+
847+
Example 4:
848+
A variant accepts a memref as the base pointer and an offset instead of a scattered TensorDesc.
849+
It combines "create scattered TensorDesc" and "store with scattered TensorDesc".
850+
The dest operand could be a raw pointer (uint64_t).
851+
Please refer to create_tdesc for the restriction of memref.
852+
```mlir
853+
%a = memref.alloc() : memref<1024xf32>
854+
%val = arith.constant dense<0.0> : vector<16xf32>
855+
%offsets = vector.step : vector<16xindex>
856+
%mask = vector.constant_mask [16]: vector<16xi1>
857+
xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
858+
l2_hint = #xegpu.cache_hint<cached>,
859+
l3_hint = #xegpu.cache_hint<cached>}
860+
: vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
861+
```
862+
771863
}];
772864

773865
let arguments = (ins
774866
XeGPU_ValueType: $value,
775-
XeGPU_TensorDesc: $TensorDesc,
867+
XeGPU_GatherScatterSourceType: $dest,
868+
Optional<XeGPU_OffsetType>: $offsets,
776869
XeGPU_MaskType: $mask,
870+
OptionalAttr<I64Attr>: $chunk_size,
777871
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
778872
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
779873
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
780874

781875
let extraClassDeclaration = extraBaseClassDeclaration # [{
876+
Type getDestType() {
877+
return getDest().getType();
878+
}
879+
880+
TypedValue<xegpu::TensorDescType> getTensorDesc() {
881+
if (auto tdescType = getTensorDescType()) {
882+
return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest());
883+
}
884+
return TypedValue<xegpu::TensorDescType>();
885+
}
886+
782887
xegpu::TensorDescType getTensorDescType() {
783-
return getTensorDesc().getType();
888+
return dyn_cast<xegpu::TensorDescType>(getDestType());
784889
}
785890

786891
VectorType getValueType() {
@@ -792,8 +897,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
792897
}
793898
}];
794899

795-
let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
796-
`:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
900+
let assemblyFormat = [{
901+
$value `,`
902+
$dest
903+
(`[` $offsets^ `]`)? `,`
904+
$mask
905+
prop-dict
906+
attr-dict `:` type(operands)
907+
}];
908+
909+
let builders = [
910+
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
911+
"xegpu::CachePolicyAttr": $l1_hint,
912+
"xegpu::CachePolicyAttr": $l2_hint,
913+
"xegpu::CachePolicyAttr": $l3_hint)>
914+
];
797915

798916
let hasVerifier = 1;
799917
}

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
189189
let genVerifyDecl = 1;
190190
}
191191

192+
def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
192193

193194
def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
194195
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";

0 commit comments

Comments
 (0)