@@ -395,3 +395,53 @@ func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
395
395
// CHECK : linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
396
396
// CHECK : linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
397
397
// CHECK : linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
398
+
399
+ func @promote_first_subview_matmul (%arg0: memref <?x?xf32 , offset : ?, strides : [?, 1 ]>,
400
+ %arg1: memref <?x?xf32 , offset : ?, strides : [?, 1 ]>,
401
+ %arg2: memref <?x?xf32 , offset : ?, strides : [?, 1 ]>) {
402
+ %c2000 = constant 2000 : index
403
+ %c3000 = constant 3000 : index
404
+ %c4000 = constant 4000 : index
405
+ %c0 = constant 0 : index
406
+ %c1 = constant 1 : index
407
+ %0 = dim %arg0 , 0 : memref <?x?xf32 , offset : ?, strides : [?, 1 ]>
408
+ %1 = dim %arg0 , 1 : memref <?x?xf32 , offset : ?, strides : [?, 1 ]>
409
+ %2 = dim %arg1 , 1 : memref <?x?xf32 , offset : ?, strides : [?, 1 ]>
410
+ loop.for %arg3 = %c0 to %0 step %c2000 {
411
+ loop.for %arg4 = %c0 to %2 step %c3000 {
412
+ loop.for %arg5 = %c0 to %1 step %c4000 {
413
+ %3 = std.subview %arg0 [%arg3 , %arg5 ][%c2000 , %c4000 ][%c1 , %c1 ] :
414
+ memref <?x?xf32 , offset : ?, strides : [?, 1 ]> to memref <?x?xf32 , offset : ?, strides : [?, ?]>
415
+ %4 = std.subview %arg1 [%arg5 , %arg4 ][%c4000 , %c3000 ][%c1 , %c1 ] :
416
+ memref <?x?xf32 , offset : ?, strides : [?, 1 ]> to memref <?x?xf32 , offset : ?, strides : [?, ?]>
417
+ %5 = std.subview %arg2 [%arg3 , %arg4 ][%c2000 , %c3000 ][%c1 , %c1 ] :
418
+ memref <?x?xf32 , offset : ?, strides : [?, 1 ]> to memref <?x?xf32 , offset : ?, strides : [?, ?]>
419
+ linalg.matmul (%3 , %4 , %5 ) {__internal_linalg_transform__ = " _promote_first_view_" } :
420
+ memref <?x?xf32 , offset : ?, strides : [?, ?]>,
421
+ memref <?x?xf32 , offset : ?, strides : [?, ?]>,
422
+ memref <?x?xf32 , offset : ?, strides : [?, ?]>
423
+ }
424
+ }
425
+ }
426
+ return
427
+ }
428
+ // CHECK-LABEL: func @promote_first_subview_matmul
429
+ // CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c2000 {
430
+ // CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c3000 {
431
+ // CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c4000 {
432
+ // CHECK: %[[s0:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
433
+ // CHECK: %[[s1:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
434
+ // CHECK: %[[s2:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
435
+ // CHECK: %[[a0:.*]] = alloc({{%.*}}) : memref<?xi8>
436
+ // CHECK: %[[v0:.*]] = std.view %[[a0]][][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
437
+ // CHECK: %[[l0:.*]] = subview %[[v0]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32> to memref<?x?xf32, #[[map:.*]]>
438
+ // CHECK-NOT: %[[a1:.*]] = alloc({{%.*}}) : memref<?xi8>
439
+ // CHECK-NOT: %[[v1:.*]] = std.view %[[a1]][][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
440
+ // CHECK-NOT: %[[l0:.*]] = subview %[[v1]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32> to memref<?x?xf32, #[[map]]>
441
+ // CHECK-NOT: %[[a2:.*]] = alloc({{%.*}}) : memref<?xi8>
442
+ // CHECK-NOT: %[[v2:.*]] = std.view %[[a2]][][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
443
+ // CHECK-NOT: %[[l0:.*]] = subview %[[v2]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32> to memref<?x?xf32, #[[map]]>
444
+ // CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
445
+ // CHECK-NOT: linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
446
+ // CHECK-NOT: linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>^
447
+ // CHECK: linalg.matmul(%[[v0]], %[[s1]], %[[s2]]) : memref<?x?xf32>, memref<?x?xf32, #[[map]]>, memref<?x?xf32, #[[map]]>
0 commit comments