|
| 1 | +// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s |
| 2 | + |
// Test entry point: fills a 2x6 f32 buffer on the host side, launches a GPU
// kernel that applies gpu.all_reduce with "add" and with "mul" to the values
// loaded by each block, and prints both result buffers so FileCheck can
// compare them against the CHECK lines below.
| 3 | +func @main() { |
// %data is the 2x6 input; %sum and %mul each hold one f32 result per row.
| 4 | + %data = alloc() : memref<2x6xf32> |
| 5 | + %sum = alloc() : memref<2xf32> |
| 6 | + %mul = alloc() : memref<2xf32> |
// f32 values written into row 0 (%cst2 is reused for row 1 as well).
| 7 | + %cst0 = constant 0.0 : f32 |
| 8 | + %cst1 = constant 1.0 : f32 |
| 9 | + %cst2 = constant 2.0 : f32 |
| 10 | + %cst4 = constant 4.0 : f32 |
| 11 | + %cst8 = constant 8.0 : f32 |
| 12 | + %cst16 = constant 16.0 : f32 |
| 13 | + |
// Remaining f32 values that only appear in row 1.
| 14 | + %cst3 = constant 3.0 : f32 |
| 15 | + %cst6 = constant 6.0 : f32 |
| 16 | + %cst7 = constant 7.0 : f32 |
| 17 | + %cst10 = constant 10.0 : f32 |
| 18 | + %cst11 = constant 11.0 : f32 |
| 19 | + |
// Index constants: used as store subscripts and as the launch dimensions.
| 20 | + %c0 = constant 0 : index |
| 21 | + %c1 = constant 1 : index |
| 22 | + %c2 = constant 2 : index |
| 23 | + %c3 = constant 3 : index |
| 24 | + %c4 = constant 4 : index |
| 25 | + %c5 = constant 5 : index |
| 26 | + %c6 = constant 6 : index |
| 27 | + |
// Row 0 = [0, 1, 2, 4, 8, 16]. The zero makes the expected product of this
// row 0, and the expected sum is 31.
| 28 | + store %cst0, %data[%c0, %c0] : memref<2x6xf32> |
| 29 | + store %cst1, %data[%c0, %c1] : memref<2x6xf32> |
| 30 | + store %cst2, %data[%c0, %c2] : memref<2x6xf32> |
| 31 | + store %cst4, %data[%c0, %c3] : memref<2x6xf32> |
| 32 | + store %cst8, %data[%c0, %c4] : memref<2x6xf32> |
| 33 | + store %cst16, %data[%c0, %c5] : memref<2x6xf32> |
| 34 | + |
// Row 1 = [2, 3, 6, 7, 10, 11]. Expected sum is 39 and expected product is
// 2*3*6*7*10*11 = 27720.
| 35 | + store %cst2, %data[%c1, %c0] : memref<2x6xf32> |
| 36 | + store %cst3, %data[%c1, %c1] : memref<2x6xf32> |
| 37 | + store %cst6, %data[%c1, %c2] : memref<2x6xf32> |
| 38 | + store %cst7, %data[%c1, %c3] : memref<2x6xf32> |
| 39 | + store %cst10, %data[%c1, %c4] : memref<2x6xf32> |
| 40 | + store %cst11, %data[%c1, %c5] : memref<2x6xf32> |
| 41 | + |
// Launch shape is 2x1x1 blocks of 6x1x1 threads: the block id %bx selects the
// row and the thread id %tx selects the column, so each thread loads exactly
// one element. Every thread then writes its all_reduce result to the slot
// indexed by its block (%sum[%bx] / %mul[%bx]).
// NOTE(review): this relies on gpu.all_reduce yielding the same value to all
// threads of a block, so the concurrent stores agree - confirm against the
// GPU dialect documentation.
| 42 | + // ADD + MUL |
| 43 | + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1) |
| 44 | + threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) { |
| 45 | + %val = load %data[%bx, %tx] : memref<2x6xf32> |
| 46 | + %reduced0 = "gpu.all_reduce"(%val) ({}) { op = "add" } : (f32) -> (f32) |
| 47 | + store %reduced0, %sum[%bx] : memref<2xf32> |
| 48 | + %reduced1 = "gpu.all_reduce"(%val) ({}) { op = "mul" } : (f32) -> (f32) |
| 49 | + store %reduced1, %mul[%bx] : memref<2xf32> |
| 50 | + gpu.terminator |
| 51 | + } |
| 52 | + |
// Cast to an unranked memref to match the print_memref_f32 signature, then
// print the per-row sums.
| 53 | + %ptr_sum = memref_cast %sum : memref<2xf32> to memref<*xf32> |
| 54 | + call @print_memref_f32(%ptr_sum) : (memref<*xf32>) -> () |
| 55 | + // CHECK: [31, 39] |
| 56 | + |
// Same cast-and-print sequence for the per-row products.
| 57 | + %ptr_mul = memref_cast %mul : memref<2xf32> to memref<*xf32> |
| 58 | + call @print_memref_f32(%ptr_mul) : (memref<*xf32>) -> () |
| 59 | + // CHECK: [0, 27720] |
| 60 | + |
| 61 | + return |
| 62 | +} |
| 63 | + |
// Body-less declaration of the printing helper called from @main; it takes an
// unranked f32 memref and returns nothing.
// NOTE(review): the definition is presumably resolved at run time from the
// libmlir_runner_utils shared library listed in the RUN line - confirm.
| 64 | +func @print_memref_f32(memref<*xf32>)
0 commit comments