Skip to content

[NVPTX] Remove UnsafeFPMath uses #151479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,6 @@ bool NVPTXDAGToDAGISel::allowFMA() const {
return TL->allowFMA(*MF, OptLevel);
}

bool NVPTXDAGToDAGISel::allowUnsafeFPMath() const {
const NVPTXTargetLowering *TL = Subtarget->getTargetLowering();
return TL->allowUnsafeFPMath(*MF);
}

/// Returns the current value of EnableRsqrtOpt.
bool NVPTXDAGToDAGISel::doRsqrtOpt() const {
  return EnableRsqrtOpt;
}

/// Select - Select instructions not customized! Used for
Expand Down
1 change: 0 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool usePrecSqrtF32(const SDNode *N) const;
bool useF32FTZ() const;
bool allowFMA() const;
bool allowUnsafeFPMath() const;
bool doRsqrtOpt() const;

NVPTXScopes Scopes{};
Expand Down
27 changes: 4 additions & 23 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,6 @@ NVPTXTargetLowering::getDivF32Level(const MachineFunction &MF,
if (UsePrecDivF32.getNumOccurrences() > 0)
return UsePrecDivF32;

// Otherwise, use div.approx if fast math is enabled
if (allowUnsafeFPMath(MF))
return NVPTX::DivPrecisionLevel::Approx;

const SDNodeFlags Flags = N.getFlags();
if (Flags.hasApproximateFuncs())
return NVPTX::DivPrecisionLevel::Approx;
Expand All @@ -142,10 +138,6 @@ bool NVPTXTargetLowering::usePrecSqrtF32(const MachineFunction &MF,
if (UsePrecSqrtF32.getNumOccurrences() > 0)
return UsePrecSqrtF32;

// Otherwise, use sqrt.approx if fast math is enabled
if (allowUnsafeFPMath(MF))
return false;

if (N) {
const SDNodeFlags Flags = N->getFlags();
if (Flags.hasApproximateFuncs())
Expand Down Expand Up @@ -2687,8 +2679,7 @@ static SDValue lowerROT(SDValue Op, SelectionDAG &DAG) {
SDLoc(Op), Opcode, DAG);
}

static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG,
bool AllowUnsafeFPMath) {
static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG) {
// Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
// i.e. "poor man's fmod()". When y is infinite, x is returned. This matches
// the semantics of LLVM's frem.
Expand All @@ -2705,7 +2696,7 @@ static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG,
SDValue Sub = DAG.getNode(ISD::FSUB, DL, Ty, X, Mul,
Flags | SDNodeFlags::AllowContract);

if (AllowUnsafeFPMath || Flags.hasNoInfs())
if (Flags.hasNoInfs())
return Sub;

// If Y is infinite, return X
Expand Down Expand Up @@ -2845,7 +2836,7 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::CTLZ:
return lowerCTLZCTPOP(Op, DAG);
case ISD::FREM:
return lowerFREM(Op, DAG, allowUnsafeFPMath(DAG.getMachineFunction()));
return lowerFREM(Op, DAG);

default:
llvm_unreachable("Custom lowering not defined for operation");
Expand Down Expand Up @@ -4718,17 +4709,7 @@ bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast)
return true;

return allowUnsafeFPMath(MF);
}

bool NVPTXTargetLowering::allowUnsafeFPMath(const MachineFunction &MF) const {
// Honor TargetOptions flags that explicitly say unsafe math is okay.
if (MF.getTarget().Options.UnsafeFPMath)
return true;

// Allow unsafe math if unsafe-fp-math attribute explicitly says so.
const Function &F = MF.getFunction();
return F.getFnAttribute("unsafe-fp-math").getValueAsBool();
return false;
}

static bool isConstZero(const SDValue &Operand) {
Expand Down
1 change: 0 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ class NVPTXTargetLowering : public TargetLowering {
unsigned combineRepeatedFPDivisors() const override { return 2; }

bool allowFMA(MachineFunction &MF, CodeGenOptLevel OptLevel) const;
bool allowUnsafeFPMath(const MachineFunction &MF) const;

bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
EVT) const override {
Expand Down
5 changes: 2 additions & 3 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1181,9 +1181,8 @@ defm FMA_F64 : FMA<F64RT, allow_ftz = false>;
// sin/cos/tanh

class UnaryOpAllowsApproxFn<SDPatternOperator operator>
: PatFrag<(ops node:$A),
(operator node:$A), [{
return allowUnsafeFPMath() || N->getFlags().hasApproximateFuncs();
: PatFrag<(ops node:$A), (operator node:$A), [{
return N->getFlags().hasApproximateFuncs();
}]>;

def SIN_APPROX_f32 :
Expand Down
8 changes: 4 additions & 4 deletions llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_80 -mattr=+ptx71 --enable-unsafe-fp-math | FileCheck --check-prefixes=CHECK %s
; RUN: %if ptxas-11.8 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_80 -mattr=+ptx71 --enable-unsafe-fp-math | %ptxas-verify -arch=sm_80 %}
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_80 -mattr=+ptx71 | FileCheck --check-prefixes=CHECK %s
; RUN: %if ptxas-11.8 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_80 -mattr=+ptx71 | %ptxas-verify -arch=sm_80 %}

target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"

Expand All @@ -22,7 +22,7 @@ define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 {
; CHECK-NEXT: cvt.rn.bf16x2.f32 %r5, %r4, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r5;
; CHECK-NEXT: ret;
%r = call <2 x bfloat> @llvm.sin.f16(<2 x bfloat> %a)
%r = call afn <2 x bfloat> @llvm.sin.f16(<2 x bfloat> %a)
ret <2 x bfloat> %r
}

Expand All @@ -41,7 +41,7 @@ define <2 x bfloat> @test_cos(<2 x bfloat> %a) #0 #1 {
; CHECK-NEXT: cvt.rn.bf16x2.f32 %r5, %r4, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r5;
; CHECK-NEXT: ret;
%r = call <2 x bfloat> @llvm.cos.f16(<2 x bfloat> %a)
%r = call afn <2 x bfloat> @llvm.cos.f16(<2 x bfloat> %a)
ret <2 x bfloat> %r
}

9 changes: 4 additions & 5 deletions llvm/test/CodeGen/NVPTX/f16-instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -886,8 +886,8 @@ define half @test_sqrt(half %a) #0 {
; CHECK: cvt.rn.f16.f32 [[R:%rs[0-9]+]], [[RF]];
; CHECK: st.param.b16 [func_retval0], [[R]];
; CHECK: ret;
define half @test_sin(half %a) #0 #1 {
%r = call half @llvm.sin.f16(half %a)
define half @test_sin(half %a) #0 {
%r = call afn half @llvm.sin.f16(half %a)
ret half %r
}

Expand All @@ -900,8 +900,8 @@ define half @test_sin(half %a) #0 #1 {
; CHECK: cvt.rn.f16.f32 [[R:%rs[0-9]+]], [[RF]];
; CHECK: st.param.b16 [func_retval0], [[R]];
; CHECK: ret;
define half @test_cos(half %a) #0 #1 {
%r = call half @llvm.cos.f16(half %a)
define half @test_cos(half %a) #0 {
%r = call afn half @llvm.cos.f16(half %a)
ret half %r
}

Expand Down Expand Up @@ -1183,4 +1183,3 @@ define <2 x half> @test_neg_f16x2(<2 x half> noundef %arg) #0 {
}

attributes #0 = { nounwind }
attributes #1 = { "unsafe-fp-math" = "true" }
9 changes: 4 additions & 5 deletions llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1674,7 +1674,7 @@ define <2 x half> @test_sqrt(<2 x half> %a) #0 {
; ret <2 x half> %r
;}

define <2 x half> @test_sin(<2 x half> %a) #0 #1 {
define <2 x half> @test_sin(<2 x half> %a) #0 {
; CHECK-LABEL: test_sin(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<5>;
Expand All @@ -1692,11 +1692,11 @@ define <2 x half> @test_sin(<2 x half> %a) #0 #1 {
; CHECK-NEXT: mov.b32 %r6, {%rs4, %rs3};
; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
; CHECK-NEXT: ret;
%r = call <2 x half> @llvm.sin.f16(<2 x half> %a)
%r = call afn <2 x half> @llvm.sin.f16(<2 x half> %a)
ret <2 x half> %r
}

define <2 x half> @test_cos(<2 x half> %a) #0 #1 {
define <2 x half> @test_cos(<2 x half> %a) #0 {
; CHECK-LABEL: test_cos(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<5>;
Expand All @@ -1714,7 +1714,7 @@ define <2 x half> @test_cos(<2 x half> %a) #0 #1 {
; CHECK-NEXT: mov.b32 %r6, {%rs4, %rs3};
; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
; CHECK-NEXT: ret;
%r = call <2 x half> @llvm.cos.f16(<2 x half> %a)
%r = call afn <2 x half> @llvm.cos.f16(<2 x half> %a)
ret <2 x half> %r
}

Expand Down Expand Up @@ -2330,4 +2330,3 @@ define void @test_store_2xhalf(ptr %p1, ptr %p2, <2 x half> %v) {


attributes #0 = { nounwind }
attributes #1 = { "unsafe-fp-math" = "true" }
9 changes: 4 additions & 5 deletions llvm/test/CodeGen/NVPTX/f32x2-instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1627,7 +1627,7 @@ define <2 x float> @test_sqrt(<2 x float> %a) #0 {
; ret <2 x float> %r
;}

define <2 x float> @test_sin(<2 x float> %a) #0 #1 {
define <2 x float> @test_sin(<2 x float> %a) #0 {
; CHECK-LABEL: test_sin(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
Expand All @@ -1640,11 +1640,11 @@ define <2 x float> @test_sin(<2 x float> %a) #0 #1 {
; CHECK-NEXT: sin.approx.f32 %r4, %r1;
; CHECK-NEXT: st.param.v2.b32 [func_retval0], {%r4, %r3};
; CHECK-NEXT: ret;
%r = call <2 x float> @llvm.sin(<2 x float> %a)
%r = call afn <2 x float> @llvm.sin(<2 x float> %a)
ret <2 x float> %r
}

define <2 x float> @test_cos(<2 x float> %a) #0 #1 {
define <2 x float> @test_cos(<2 x float> %a) #0 {
; CHECK-LABEL: test_cos(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
Expand All @@ -1657,7 +1657,7 @@ define <2 x float> @test_cos(<2 x float> %a) #0 #1 {
; CHECK-NEXT: cos.approx.f32 %r4, %r1;
; CHECK-NEXT: st.param.v2.b32 [func_retval0], {%r4, %r3};
; CHECK-NEXT: ret;
%r = call <2 x float> @llvm.cos(<2 x float> %a)
%r = call afn <2 x float> @llvm.cos(<2 x float> %a)
ret <2 x float> %r
}

Expand Down Expand Up @@ -2146,5 +2146,4 @@ define void @test_trunc_to_v2f16(<2 x float> %a, ptr %p) {


attributes #0 = { nounwind }
attributes #1 = { "unsafe-fp-math" = "true" }
attributes #2 = { "denormal-fp-math"="preserve-sign" }
Loading