Skip to content

[flang][CUDA] Apply intrinsic operator overrides #151018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 30, 2025
Merged

Conversation

klausler
Copy link
Contributor

Fortran's intrinsic numeric and relational operators can be overridden with explicit interfaces so long as one or more of the dummy arguments have the DEVICE attribute. Semantics already allows this without complaint, but fails to replace the operations with the defined specific procedure calls when analyzing expressions.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:semantics labels Jul 28, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 28, 2025

@llvm/pr-subscribers-flang-semantics

Author: Peter Klausler (klausler)

Changes

Fortran's intrinsic numeric and relational operators can be overridden with explicit interfaces so long as one or more of the dummy arguments have the DEVICE attribute. Semantics already allows this without complaint, but fails to replace the operations with the defined specific procedure calls when analyzing expressions.


Full diff: https://github.com/llvm/llvm-project/pull/151018.diff

6 Files Affected:

  • (modified) flang/include/flang/Semantics/semantics.h (+9-1)
  • (modified) flang/lib/Semantics/check-cuda.cpp (+4-5)
  • (modified) flang/lib/Semantics/check-declarations.cpp (+3-2)
  • (modified) flang/lib/Semantics/expression.cpp (+68-7)
  • (added) flang/test/Semantics/bug1214.cuf (+49)
  • (modified) flang/test/Semantics/cuf11.cuf (+1-1)
diff --git a/flang/include/flang/Semantics/semantics.h b/flang/include/flang/Semantics/semantics.h
index 0dbca51ee0dcf..b87e20f993663 100644
--- a/flang/include/flang/Semantics/semantics.h
+++ b/flang/include/flang/Semantics/semantics.h
@@ -123,6 +123,9 @@ class SemanticsContext {
   std::map<const Symbol *, SourceName> &moduleFileOutputRenamings() {
     return moduleFileOutputRenamings_;
   }
+  bool anyCUDAHostDeviceIntrinsicOperatorOverride() const {
+    return anyCUDAHostDeviceIntrinsicOperatorOverride_;
+  }
 
   SemanticsContext &set_location(
       const std::optional<parser::CharBlock> &___location) {
@@ -162,11 +165,15 @@ class SemanticsContext {
     warningsAreErrors_ = x;
     return *this;
   }
-
   SemanticsContext &set_debugModuleWriter(bool x) {
     debugModuleWriter_ = x;
     return *this;
   }
+  SemanticsContext &set_anyCUDAHostDeviceIntrinsicOperatorOverride(
+      bool yes = true) {
+    anyCUDAHostDeviceIntrinsicOperatorOverride_ = yes;
+    return *this;
+  }
 
   const DeclTypeSpec &MakeNumericType(TypeCategory, int kind = 0);
   const DeclTypeSpec &MakeLogicalType(int kind = 0);
@@ -352,6 +359,7 @@ class SemanticsContext {
   std::map<const Symbol *, SourceName> moduleFileOutputRenamings_;
   UnorderedSymbolSet isDefined_;
   std::list<ProgramTree> programTrees_;
+  bool anyCUDAHostDeviceIntrinsicOperatorOverride_{false};
 };
 
 class Semantics {
diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp
index b01147606a99c..9b48432e049b9 100644
--- a/flang/lib/Semantics/check-cuda.cpp
+++ b/flang/lib/Semantics/check-cuda.cpp
@@ -761,14 +761,13 @@ void CUDAChecker::Enter(const parser::AssignmentStmt &x) {
   // legal.
   if (nbLhs == 0 && nbRhs > 1) {
     context_.Say(lhsLoc,
-        "More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US);
+        "More than one reference to a CUDA object on the right hand side of the assignment"_err_en_US);
   }
 
-  if (Fortran::evaluate::HasCUDADeviceAttrs(assign->lhs) &&
-      Fortran::evaluate::HasCUDAImplicitTransfer(assign->rhs)) {
+  if (evaluate::HasCUDADeviceAttrs(assign->lhs) &&
+      evaluate::HasCUDAImplicitTransfer(assign->rhs)) {
     if (GetNbOfCUDAManagedOrUnifiedSymbols(assign->lhs) == 1 &&
-        GetNbOfCUDAManagedOrUnifiedSymbols(assign->rhs) == 1 &&
-        GetNbOfCUDADeviceSymbols(assign->rhs) == 1) {
+        GetNbOfCUDAManagedOrUnifiedSymbols(assign->rhs) == 1 && nbRhs == 1) {
       return; // This is a special case handled on the host.
     }
     context_.Say(lhsLoc, "Unsupported CUDA data transfer"_err_en_US);
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index a2f2906af10b8..1ae8875f267bf 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -2081,12 +2081,13 @@ static bool ConflictsWithIntrinsicAssignment(const Procedure &proc) {
 }
 
 static bool ConflictsWithIntrinsicOperator(
-    const GenericKind &kind, const Procedure &proc) {
+    const GenericKind &kind, const Procedure &proc, SemanticsContext &context) {
   if (!kind.IsIntrinsicOperator()) {
     return false;
   }
   const auto &arg0Data{std::get<DummyDataObject>(proc.dummyArguments[0].u)};
   if (CUDAHostDeviceDiffer(proc, arg0Data)) {
+    context.set_anyCUDAHostDeviceIntrinsicOperatorOverride();
     return false;
   }
   const auto &arg0TnS{arg0Data.type};
@@ -2167,7 +2168,7 @@ bool CheckHelper::CheckDefinedOperator(SourceName opName, GenericKind kind,
     }
   } else if (!checkDefinedOperatorArgs(opName, specific, proc)) {
     return false; // error was reported
-  } else if (ConflictsWithIntrinsicOperator(kind, proc)) {
+  } else if (ConflictsWithIntrinsicOperator(kind, proc, context_)) {
     msg = "%s function '%s' conflicts with intrinsic operator"_err_en_US;
   }
   if (msg) {
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 14473724f0f40..d3c8a798c7f63 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -165,10 +165,17 @@ class ArgumentAnalyzer {
   bool CheckForNullPointer(const char *where = "as an operand here");
   bool CheckForAssumedRank(const char *where = "as an operand here");
 
+  bool AnyCUDADeviceData() const;
+  // Returns true if an interface has been defined for an intrinsic operator
+  // with one or more device operands.
+  bool HasDeviceDefinedIntrinsicOpOverride(const char *) const;
+  template <typename E> bool HasDeviceDefinedIntrinsicOpOverride(E opr) const {
+    return HasDeviceDefinedIntrinsicOpOverride(
+        context_.context().languageFeatures().GetNames(opr));
+  }
+
   // Find and return a user-defined operator or report an error.
   // The provided message is used if there is no such operator.
-  // If a definedOpSymbolPtr is provided, the caller must check
-  // for its accessibility.
   MaybeExpr TryDefinedOp(
       const char *, parser::MessageFixedText, bool isUserOp = false);
   template <typename E>
@@ -183,6 +190,8 @@ class ArgumentAnalyzer {
   void Dump(llvm::raw_ostream &);
 
 private:
+  bool HasDeviceDefinedIntrinsicOpOverride(
+      const std::vector<const char *> &) const;
   MaybeExpr TryDefinedOp(
       const std::vector<const char *> &, parser::MessageFixedText);
   MaybeExpr TryBoundOp(const Symbol &, int passIndex);
@@ -202,7 +211,7 @@ class ArgumentAnalyzer {
   void SayNoMatch(
       const std::string &, bool isAssignment = false, bool isAmbiguous = false);
   std::string TypeAsFortran(std::size_t);
-  bool AnyUntypedOrMissingOperand();
+  bool AnyUntypedOrMissingOperand() const;
 
   ExpressionAnalyzer &context_;
   ActualArguments actuals_;
@@ -4497,13 +4506,18 @@ void ArgumentAnalyzer::Analyze(
 bool ArgumentAnalyzer::IsIntrinsicRelational(RelationalOperator opr,
     const DynamicType &leftType, const DynamicType &rightType) const {
   CHECK(actuals_.size() == 2);
-  return semantics::IsIntrinsicRelational(
-      opr, leftType, GetRank(0), rightType, GetRank(1));
+  return !(context_.context().anyCUDAHostDeviceIntrinsicOperatorOverride() &&
+             HasDeviceDefinedIntrinsicOpOverride(opr)) &&
+      semantics::IsIntrinsicRelational(
+          opr, leftType, GetRank(0), rightType, GetRank(1));
 }
 
 bool ArgumentAnalyzer::IsIntrinsicNumeric(NumericOperator opr) const {
   std::optional<DynamicType> leftType{GetType(0)};
-  if (actuals_.size() == 1) {
+  if (context_.context().anyCUDAHostDeviceIntrinsicOperatorOverride() &&
+      HasDeviceDefinedIntrinsicOpOverride(AsFortran(opr))) {
+    return false;
+  } else if (actuals_.size() == 1) {
     if (IsBOZLiteral(0)) {
       return opr == NumericOperator::Add; // unary '+'
     } else {
@@ -4617,6 +4631,53 @@ bool ArgumentAnalyzer::CheckForAssumedRank(const char *where) {
   return true;
 }
 
+bool ArgumentAnalyzer::AnyCUDADeviceData() const {
+  for (const std::optional<ActualArgument> &arg : actuals_) {
+    if (arg) {
+      if (const Expr<SomeType> *expr{arg->UnwrapExpr()}) {
+        if (HasCUDADeviceAttrs(*expr)) {
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
+// Some operations can be defined with explicit non-type-bound interfaces
+// that would erroneously conflict with intrinsic operations in their
+// types and ranks but have one or more dummy arguments with the DEVICE
+// attribute.
+bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
+    const char *opr) const {
+  if (AnyCUDADeviceData() && !AnyUntypedOrMissingOperand()) {
+    std::string oprNameString{"operator("s + opr + ')'};
+    parser::CharBlock oprName{oprNameString};
+    parser::Messages buffer;
+    auto restorer{context_.GetContextualMessages().SetMessages(buffer)};
+    const auto &scope{context_.context().FindScope(source_)};
+    if (Symbol * generic{scope.FindSymbol(oprName)}) {
+      parser::Name name{generic->name(), generic};
+      const Symbol *resultSymbol{nullptr};
+      if (context_.AnalyzeDefinedOp(
+              name, ActualArguments{actuals_}, resultSymbol)) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
+    const std::vector<const char *> &oprNames) const {
+  for (const char *opr : oprNames) {
+    if (HasDeviceDefinedIntrinsicOpOverride(opr)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 MaybeExpr ArgumentAnalyzer::TryDefinedOp(
     const char *opr, parser::MessageFixedText error, bool isUserOp) {
   if (AnyUntypedOrMissingOperand()) {
@@ -5135,7 +5196,7 @@ std::string ArgumentAnalyzer::TypeAsFortran(std::size_t i) {
   }
 }
 
-bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() {
+bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() const {
   for (const auto &actual : actuals_) {
     if (!actual ||
         (!actual->GetType() && !IsBareNullPointer(actual->UnwrapExpr()))) {
diff --git a/flang/test/Semantics/bug1214.cuf b/flang/test/Semantics/bug1214.cuf
new file mode 100644
index 0000000000000..114fad15ea500
--- /dev/null
+++ b/flang/test/Semantics/bug1214.cuf
@@ -0,0 +1,49 @@
+! RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
+module overrides
+  type realResult
+    real a
+  end type
+  interface operator(*)
+    procedure :: multHostDevice, multDeviceHost
+  end interface
+  interface assignment(=)
+    procedure :: assignHostResult, assignDeviceResult
+  end interface
+ contains
+  elemental function multHostDevice(x, y) result(result)
+    real, intent(in) :: x
+    real, intent(in), device :: y
+    type(realResult) result
+    result%a = x * y
+  end
+  elemental function multDeviceHost(x, y) result(result)
+    real, intent(in), device :: x
+    real, intent(in) :: y
+    type(realResult) result
+    result%a = x * y
+  end
+  elemental subroutine assignHostResult(lhs, rhs)
+    real, intent(out) :: lhs
+    type(realResult), intent(in) :: rhs
+    lhs = rhs%a
+  end
+  elemental subroutine assignDeviceResult(lhs, rhs)
+    real, intent(out), device :: lhs
+    type(realResult), intent(in) :: rhs
+    lhs = rhs%a
+  end
+end
+
+program p
+  use overrides
+  real, device :: da, db
+  real :: ha, hb
+!CHECK: CALL assigndeviceresult(db,multhostdevice(2._4,da))
+  db = 2. * da
+!CHECK: CALL assigndeviceresult(db,multdevicehost(da,2._4))
+  db = da * 2.
+!CHECK: CALL assignhostresult(ha,multhostdevice(2._4,da))
+  ha = 2. * da
+!CHECK: CALL assignhostresult(ha,multdevicehost(da,2._4))
+  ha = da * 2.
+end
diff --git a/flang/test/Semantics/cuf11.cuf b/flang/test/Semantics/cuf11.cuf
index 554ac258e5510..1f5beb02aee45 100644
--- a/flang/test/Semantics/cuf11.cuf
+++ b/flang/test/Semantics/cuf11.cuf
@@ -16,7 +16,7 @@ subroutine sub1()
   real, device :: adev(10), bdev(10)
   real :: ahost(10)
 
-!ERROR: More than one reference to a CUDA object on the right hand side of the assigment
+!ERROR: More than one reference to a CUDA object on the right hand side of the assignment
   ahost = adev + bdev
 
   ahost = adev + adev

Copy link
Contributor

@akuhlens akuhlens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NICE!

Fortran's intrinsic numeric and relational operators can be
overridden with explicit interfaces so long as one or more of
the dummy arguments have the DEVICE attribute.  Semantics already
allows this without complaint, but fails to replace the operations
with the defined specific procedure calls when analyzing expressions.
Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for fixing the typo

@klausler klausler merged commit b01ab53 into llvm:main Jul 30, 2025
9 checks passed
@klausler klausler deleted the bug1214 branch July 30, 2025 18:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:semantics flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants