@@ -628,35 +628,71 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
628
628
As compared to prefetch_nd, which works on non-scattered TensorDesc,
629
629
it works on scattered TensorDesc instead.
630
630
631
- Example:
631
+ Example 1 :
632
632
```mlir
633
633
xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
634
634
l2_hint = #xegpu.cache_hint<cached>,
635
635
l3_hint = #xegpu.cache_hint<cached>}
636
636
: !xegpu.tensor_desc<16xf16>
637
637
```
638
+
639
+ Example 2:
640
+ A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
641
+ It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
642
+ The source operand could be a raw pointer (uint64_t).
643
+ Please refer to create_tdesc for the restriction of memref.
644
+ ```mlir
645
+ %a = memref.alloc() : memref<1024xf32>
646
+ %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
647
+ xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>,
648
+ l2_hint = #xegpu.cache_hint<cached>,
649
+ l3_hint = #xegpu.cache_hint<cached>}
650
+ : memref<1024xf32>, vector<4xindex>
651
+ ```
638
652
639
653
}];
640
654
641
- let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
655
+ let arguments = (ins XeGPU_GatherScatterSourceType: $source,
656
+ Optional<XeGPU_OffsetType>: $offsets,
642
657
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
643
658
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
644
659
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
645
660
646
661
let extraClassDeclaration = extraBaseClassDeclaration # [{
662
+ Type getSourceType() {
663
+ return getSource().getType();
664
+ }
665
+
666
+ TypedValue<xegpu::TensorDescType> getTensorDesc() {
667
+ if (auto tdescType = getTensorDescType()) {
668
+ return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
669
+ }
670
+ return TypedValue<xegpu::TensorDescType>();
671
+ }
672
+
647
673
xegpu::TensorDescType getTensorDescType() {
648
- return getTensorDesc().getType( );
674
+ return dyn_cast<xegpu::TensorDescType>(getSourceType() );
649
675
}
650
676
}];
651
677
652
- let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
678
+ let assemblyFormat = [{
679
+ $source
680
+ (`[` $offsets^ `]`)?
681
+ prop-dict
682
+ attr-dict `:` type(operands)
683
+ }];
684
+
685
+ let builders = [
686
+ OpBuilder<(ins "Value": $source,
687
+ "xegpu::CachePolicyAttr": $l1_hint,
688
+ "xegpu::CachePolicyAttr": $l2_hint,
689
+ "xegpu::CachePolicyAttr": $l3_hint)>
690
+ ];
653
691
654
692
let hasVerifier = 1;
655
693
}
656
694
657
- def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
658
- AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
659
- ]> {
695
+ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
660
696
let summary = "load a set of scattered data points from memory.";
661
697
662
698
let description = [{ It (aka. load) load data per each work-item. The output
@@ -687,6 +723,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
687
723
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
688
724
vector<16xi1> -> vector<16x8xf32>
689
725
```
726
+
690
727
Example 3 (SIMT mode):
691
728
```mlir
692
729
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -695,19 +732,48 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
695
732
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
696
733
vector<16xi1> -> vector<8xf32>
697
734
```
735
+
736
+ Example 4:
737
+ A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
738
+ It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
739
+ The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
740
+ for the restriction of memref.
741
+ ```mlir
742
+ %a = memref.alloc() : memref<1024xf32>
743
+ %offsets = vector.step : vector<16xindex>
744
+ %mask = vector.constant_mask [16]: vector<16xi1>
745
+ %val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
746
+ l2_hint = #xegpu.cache_hint<cached>,
747
+ l3_hint = #xegpu.cache_hint<cached>}
748
+ : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32>
749
+ ```
698
750
699
751
}];
700
752
701
- let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
753
+ let arguments = (ins XeGPU_GatherScatterSourceType: $source,
754
+ Optional<XeGPU_OffsetType>: $offsets,
702
755
XeGPU_MaskType: $mask,
756
+ OptionalAttr<I64Attr>: $chunk_size,
703
757
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
704
758
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
705
759
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
706
760
let results = (outs XeGPU_ValueType: $value);
707
761
708
762
let extraClassDeclaration = extraBaseClassDeclaration # [{
763
+
764
+ Type getSourceType() {
765
+ return getSource().getType();
766
+ }
767
+
768
+ TypedValue<xegpu::TensorDescType> getTensorDesc() {
769
+ if (auto tdescType = getTensorDescType()) {
770
+ return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
771
+ }
772
+ return TypedValue<xegpu::TensorDescType>();
773
+ }
774
+
709
775
xegpu::TensorDescType getTensorDescType() {
710
- return getTensorDesc().getType( );
776
+ return dyn_cast<xegpu::TensorDescType>(getSourceType() );
711
777
}
712
778
713
779
mlir::Type getElementType() {
@@ -725,15 +791,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
725
791
726
792
}];
727
793
728
- let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
729
- `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
794
+ let assemblyFormat = [{
795
+ $source
796
+ (`[` $offsets^ `]`)? `,`
797
+ $mask prop-dict
798
+ attr-dict `:` type(operands) `->` type($value)
799
+ }];
800
+
801
+ let builders = [
802
+ OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
803
+ "xegpu::CachePolicyAttr": $l1_hint,
804
+ "xegpu::CachePolicyAttr": $l2_hint,
805
+ "xegpu::CachePolicyAttr": $l3_hint)>
806
+ ];
730
807
731
808
let hasVerifier = 1;
732
809
}
733
810
734
- def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
735
- AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>
736
- ]> {
811
+ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
737
812
let summary = "store data to scattered memory locations.";
738
813
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
739
814
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
@@ -768,19 +843,49 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
768
843
l3_hint = #xegpu.cache_hint<write_through>}>
769
844
: vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1>
770
845
```
846
+
847
+ Example 4:
848
+ A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
849
+ It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
850
+ The dest operand could be a raw pointer (uint64_t).
851
+ Please refer to create_tdesc for the restriction of memref.
852
+ ```mlir
853
+ %a = memref.alloc() : memref<1024xf32>
854
+ %val = arith.constant dense<0.0> : vector<16xf32>
855
+ %offsets = vector.step : vector<16xindex>
856
+ %mask = vector.constant_mask [16]: vector<16xi1>
857
+ xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
858
+ l2_hint = #xegpu.cache_hint<cached>,
859
+ l3_hint = #xegpu.cache_hint<cached>}
860
+ : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32>
861
+ ```
862
+
771
863
}];
772
864
773
865
let arguments = (ins
774
866
XeGPU_ValueType: $value,
775
- XeGPU_TensorDesc: $TensorDesc,
867
+ XeGPU_GatherScatterSourceType: $dest,
868
+ Optional<XeGPU_OffsetType>: $offsets,
776
869
XeGPU_MaskType: $mask,
870
+ OptionalAttr<I64Attr>: $chunk_size,
777
871
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
778
872
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
779
873
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
780
874
781
875
let extraClassDeclaration = extraBaseClassDeclaration # [{
876
+ Type getDestType() {
877
+ return getDest().getType();
878
+ }
879
+
880
+ TypedValue<xegpu::TensorDescType> getTensorDesc() {
881
+ if (auto tdescType = getTensorDescType()) {
882
+ return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest());
883
+ }
884
+ return TypedValue<xegpu::TensorDescType>();
885
+ }
886
+
782
887
xegpu::TensorDescType getTensorDescType() {
783
- return getTensorDesc().getType( );
888
+ return dyn_cast<xegpu::TensorDescType>(getDestType() );
784
889
}
785
890
786
891
VectorType getValueType() {
@@ -792,8 +897,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
792
897
}
793
898
}];
794
899
795
- let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
796
- `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
900
+ let assemblyFormat = [{
901
+ $value `,`
902
+ $dest
903
+ (`[` $offsets^ `]`)? `,`
904
+ $mask
905
+ prop-dict
906
+ attr-dict `:` type(operands)
907
+ }];
908
+
909
+ let builders = [
910
+ OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
911
+ "xegpu::CachePolicyAttr": $l1_hint,
912
+ "xegpu::CachePolicyAttr": $l2_hint,
913
+ "xegpu::CachePolicyAttr": $l3_hint)>
914
+ ];
797
915
798
916
let hasVerifier = 1;
799
917
}
0 commit comments