diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 90383265002a3..9c74cff0d14f1 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4448,6 +4448,7 @@ def SPIRV_OC_OpUMulExtended : I32EnumAttrCase<"OpUMulExtended" def SPIRV_OC_OpSMulExtended : I32EnumAttrCase<"OpSMulExtended", 152>; def SPIRV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>; def SPIRV_OC_OpIsInf : I32EnumAttrCase<"OpIsInf", 157>; +def SPIRV_OC_OpIsFinite : I32EnumAttrCase<"OpIsFinite", 158>; def SPIRV_OC_OpOrdered : I32EnumAttrCase<"OpOrdered", 162>; def SPIRV_OC_OpUnordered : I32EnumAttrCase<"OpUnordered", 163>; def SPIRV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>; @@ -4630,7 +4631,8 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpVectorTimesMatrix, SPIRV_OC_OpMatrixTimesVector, SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry, SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, - SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, + SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpIsFinite, + SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, SPIRV_OC_OpLogicalEqual, SPIRV_OC_OpLogicalNotEqual, SPIRV_OC_OpLogicalOr, SPIRV_OC_OpLogicalAnd, SPIRV_OC_OpLogicalNot, SPIRV_OC_OpSelect, SPIRV_OC_OpIEqual, SPIRV_OC_OpINotEqual, SPIRV_OC_OpUGreaterThan, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td index ab535d7b2a304..9331fc576c7bd 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -403,6 +403,28 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual", // ----- +def SPIRV_IsFiniteOp : SPIRV_LogicalUnaryOp<"IsFinite", SPIRV_Float, []> { + let summary = "Result is true if x is an IEEE Finite, otherwise result is false"; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + x must be a scalar or vector of floating-point type. It must have the + same number of components as Result Type. + + Results are computed per component. + + #### Example: + + ```mlir + %2 = spirv.IsFinite %0: f32 + %3 = spirv.IsFinite %1: vector<4xf32> + ``` + }]; +} + +// ----- + def SPIRV_IsInfOp : SPIRV_LogicalUnaryOp<"IsInf", SPIRV_Float, []> { let summary = "Result is true if x is an IEEE Inf, otherwise result is false"; @@ -418,7 +440,7 @@ def SPIRV_IsInfOp : SPIRV_LogicalUnaryOp<"IsInf", SPIRV_Float, []> { ```mlir %2 = spirv.IsInf %0: f32 - %3 = spirv.IsInf %1: vector<4xi32> + %3 = spirv.IsInf %1: vector<4xf32> ``` }]; } @@ -442,7 +464,7 @@ def SPIRV_IsNanOp : SPIRV_LogicalUnaryOp<"IsNan", SPIRV_Float, []> { ```mlir %2 = spirv.IsNan %0: f32 - %3 = spirv.IsNan %1: vector<4xi32> + %3 = spirv.IsNan %1: vector<4xf32> ``` }]; } diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index a877ad21734a2..1787e0a44f8fd 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -488,7 +488,12 @@ namespace mlir { void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { // Core patterns - patterns.add(typeConverter, patterns.getContext()); + patterns + .add, + CheckedElementwiseOpPattern, + CheckedElementwiseOpPattern>( + typeConverter, patterns.getContext()); // GLSL patterns patterns diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir new file mode 100644 index 0000000000000..3e5f592049e7f --- /dev/null +++ b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt --convert-math-to-spirv %s | FileCheck %s + +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + + // CHECK-LABEL: @fpclassify + func.func @fpclassify(%x: f32, %v: vector<4xf32>) { + // CHECK: spirv.IsFinite %{{.*}} : f32 + %0 = math.isfinite %x : f32 + // CHECK: spirv.IsFinite %{{.*}} : vector<4xf32> + %1 = math.isfinite %v : vector<4xf32> + + // CHECK: spirv.IsNan %{{.*}} : f32 + %2 = math.isnan %x : f32 + // CHECK: spirv.IsNan %{{.*}} : vector<4xf32> + %3 = math.isnan %v : vector<4xf32> + + // CHECK: spirv.IsInf %{{.*}} : f32 + %4 = math.isinf %x : f32 + // CHECK: spirv.IsInf %{{.*}} : vector<4xf32> + %5 = math.isinf %v : vector<4xf32> + + return + } + +} diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir index d6c34645f5746..58b828877e71d 100644 --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir @@ -32,6 +32,24 @@ func.func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vecto // ----- +//===----------------------------------------------------------------------===// +// spirv.IsFinite +//===----------------------------------------------------------------------===// + +func.func @isfinite_scalar(%arg0: f32) -> i1 { + // CHECK: spirv.IsFinite {{.*}} : f32 + %0 = spirv.IsFinite %arg0 : f32 + return %0 : i1 +} + +func.func @isfinite_vector(%arg0: vector<2xf32>) -> vector<2xi1> { + // CHECK: spirv.IsFinite {{.*}} : vector<2xf32> + %0 = spirv.IsFinite %arg0 : vector<2xf32> + return %0 : vector<2xi1> +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.IsInf //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir index b2008719b021c..05cbddc048151 100644 --- a/mlir/test/Target/SPIRV/logical-ops.mlir +++ b/mlir/test/Target/SPIRV/logical-ops.mlir @@ -84,6 +84,8 @@ spirv.module Logical GLSL450 requires #spirv.vce { %15 = spirv.IsNan %arg0 : f32 // CHECK: spirv.IsInf %16 = spirv.IsInf %arg1 : f32 + // CHECK: spirv.IsFinite + %17 = spirv.IsFinite %arg0 : f32 spirv.Return } }