[mlir] [affine] affine-loop-fusion incorrectly fuses loops #152077

@bubblepipe

Consider the following MLIR code snippet:

func.func @foo(%input: memref<1x40x32x128xf32>,
               %output: memref<1x32x40x128xf32>,
               %scale: memref<1x1x40x128xf32>) {
  affine.for %i0 = 0 to 1 {
    affine.for %i1 = 0 to 32 {
      affine.for %i2 = 0 to 40 {
        affine.for %i3 = 0 to 128 {
          %val = affine.load %input[%i0, %i2, %i1, %i3] : memref<1x40x32x128xf32>
          affine.store %val, %output[%i0, %i1, %i2, %i3] : memref<1x32x40x128xf32>
        }
      }
    }
  }

  affine.for %i0 = 0 to 1 {
    affine.for %i1 = 0 to 32 {
      affine.for %i2 = 0 to 40 {
        affine.for %i3 = 0 to 128 {
          %val = affine.load %output[%i0, %i1, %i2, %i3] : memref<1x32x40x128xf32>
          %scale_val = affine.load %scale[%i0, 0, %i2, %i3] : memref<1x1x40x128xf32>
          %result = arith.mulf %val, %scale_val : f32
          affine.store %result, %output[%i0, %i1, %i2, %i3] : memref<1x32x40x128xf32>
        }
      }
    }
  }
  return
}
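
For readers who want to check the intended semantics outside of MLIR, here is a minimal Python sketch of the two unfused loop nests. It is only a model: the helper names `make` and `run_unfused` are hypothetical, nested lists stand in for memrefs, and the dimensions are parameters so it can run on small sizes (the snippet above uses 32, 40, and 128).

```python
# Hypothetical reference model of the two unfused loop nests above.
# Nested Python lists stand in for memrefs; d1, d2, d3 correspond to
# the extents 32, 40, 128 in the issue.

def make(shape, f):
    """Build a 4-D nested list whose element at (a, b, c, d) is f(a, b, c, d)."""
    n0, n1, n2, n3 = shape
    return [[[[f(a, b, c, d) for d in range(n3)]
              for c in range(n2)]
             for b in range(n1)]
            for a in range(n0)]

def run_unfused(inp, scale, d1, d2, d3):
    """Run both loop nests in order and return the resulting output buffer."""
    out = make((1, d1, d2, d3), lambda a, b, c, d: 0.0)
    # First nest: copy with dimensions 1 and 2 swapped.
    for i0 in range(1):
        for i1 in range(d1):
            for i2 in range(d2):
                for i3 in range(d3):
                    out[i0][i1][i2][i3] = inp[i0][i2][i1][i3]
    # Second nest: scale the output in place.
    for i0 in range(1):
        for i1 in range(d1):
            for i2 in range(d2):
                for i3 in range(d3):
                    val = out[i0][i1][i2][i3]
                    scale_val = scale[i0][0][i2][i3]
                    out[i0][i1][i2][i3] = val * scale_val
    return out
```

In this model, each output element ends up as `inp[0][i2][i1][i3] * scale[0][0][i2][i3]`, which is the value any fused form has to reproduce.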

The result produced by mlir-opt --affine-loop-fusion is:

module {
  func.func @foo(%arg0: memref<1x40x32x128xf32>, %arg1: memref<1x32x40x128xf32>, %arg2: memref<1x1x40x128xf32>) {
    %c0 = arith.constant 0 : index // ❌ should not have this const 
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 32 {
        affine.for %arg5 = 0 to 40 {
          affine.for %arg6 = 0 to 128 {
            %0 = affine.load %arg0[%c0, %arg5, %arg4, %arg6] : memref<1x40x32x128xf32> // ❌ Uses %c0 instead of %arg3
            affine.store %0, %arg1[%c0, %arg4, %arg5, %arg6] : memref<1x32x40x128xf32> // ❌ Uses %c0 instead of %arg3
            %1 = affine.load %arg1[%arg3, %arg4, %arg5, %arg6] : memref<1x32x40x128xf32>
            %2 = affine.load %arg2[%arg3, 0, %arg5, %arg6] : memref<1x1x40x128xf32>
            %3 = arith.mulf %1, %2 : f32
            affine.store %3, %arg1[%arg3, %arg4, %arg5, %arg6] : memref<1x32x40x128xf32>
          }
        }
      }
    }
    return
  }
}

However, the expected fused loop is:

affine.for %arg3 = 0 to 1 {
  ...
  %0 = affine.load %arg0[%arg3, %arg5, %arg4, %arg6]  // Should use %arg3
  affine.store %0, %arg1[%arg3, %arg4, %arg5, %arg6]  // Should use %arg3
  %1 = affine.load %arg1[%arg3, %arg4, %arg5, %arg6]  
  ...
}
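
To compare the two fused forms by value, the loop body can be modeled in Python. This is a rough sketch under stated assumptions, not a description of what the pass does internally: `make` and `run_fused` are hypothetical helpers, nested lists stand in for memrefs, and the flag `first_index_from_iv` selects between indexing the first dimension with the loop variable (the expected form) and with a constant 0 (the form mlir-opt produced).

```python
# Hypothetical model of the fused loop body, with a switch for how the
# first dimension of the producer load/store is indexed.

def make(shape, f):
    """Build a 4-D nested list whose element at (a, b, c, d) is f(a, b, c, d)."""
    n0, n1, n2, n3 = shape
    return [[[[f(a, b, c, d) for d in range(n3)]
              for c in range(n2)]
             for b in range(n1)]
            for a in range(n0)]

def run_fused(inp, scale, d1, d2, d3, first_index_from_iv):
    """Execute the fused nest; d1, d2, d3 play the role of 32, 40, 128."""
    out = make((1, d1, d2, d3), lambda a, b, c, d: 0.0)
    for a3 in range(1):
        for a4 in range(d1):
            for a5 in range(d2):
                for a6 in range(d3):
                    # Producer part: loop variable vs. constant 0 (%c0).
                    k = a3 if first_index_from_iv else 0
                    v = inp[k][a5][a4][a6]
                    out[k][a4][a5][a6] = v
                    # Consumer part: always indexed by the loop variable.
                    loaded = out[a3][a4][a5][a6]
                    out[a3][a4][a5][a6] = loaded * scale[a3][0][a5][a6]
    return out
```

Note that the outermost loop runs from 0 to 1, so `a3` only ever takes the value 0 in this model, and both settings of the flag yield the same values on these shapes. The sketch therefore illustrates the data flow only; the difference reported in this issue is in the form of the generated IR.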

Here is a Godbolt link for reproduction: https://godbolt.org/z/K6nzjd4d5
